Skip to content

Commit 4bf6f91

Browse files
authored
feat: Propagate error messages from worker failures to Task model (#98)
Signed-off-by: Pawel Rein <pawel.rein@prezi.com>
1 parent bc426e0 commit 4bf6f91

File tree

3 files changed

+196
-0
lines changed

3 files changed

+196
-0
lines changed

docling_jobkit/datamodel/task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Task(BaseModel):
4444
chunking_export_options: ChunkingExportOptions = ChunkingExportOptions()
4545
# scratch_dir: Optional[Path] = None
4646
processing_meta: Optional[TaskProcessingMeta] = None
47+
error_message: Optional[str] = None
4748
created_at: datetime.datetime = Field(
4849
default_factory=partial(datetime.datetime.now, datetime.timezone.utc)
4950
)

docling_jobkit/orchestrators/rq/orchestrator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class _TaskUpdate(BaseModel):
4141
task_id: str
4242
task_status: TaskStatus
4343
result_key: Optional[str] = None
44+
error_message: Optional[str] = None
4445

4546

4647
class RQOrchestrator(BaseOrchestrator):
@@ -197,6 +198,12 @@ async def _listen_for_updates(self):
197198

198199
# Update the status
199200
task.set_status(data.task_status)
201+
# Store error message on failure
202+
if (
203+
data.task_status == TaskStatus.FAILURE
204+
and data.error_message is not None
205+
):
206+
task.error_message = data.error_message
200207
# Update the results lookup
201208
if (
202209
data.task_status == TaskStatus.SUCCESS
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""Tests for error_message propagation through _TaskUpdate -> Task."""
2+
3+
from unittest.mock import AsyncMock, MagicMock, patch
4+
5+
import pytest
6+
7+
from docling_jobkit.datamodel.task import Task
8+
from docling_jobkit.datamodel.task_meta import TaskStatus
9+
from docling_jobkit.datamodel.task_targets import InBodyTarget
10+
from docling_jobkit.orchestrators.rq.orchestrator import (
11+
RQOrchestrator,
12+
RQOrchestratorConfig,
13+
_TaskUpdate,
14+
)
15+
16+
17+
def _make_pubsub_message(
    task_id: str,
    status: TaskStatus,
    error_message: str | None = None,
    result_key: str | None = None,
) -> dict:
    """Build a pub/sub-style message dict wrapping a serialized ``_TaskUpdate``.

    Args:
        task_id: Identifier placed in the update's ``task_id`` field.
        status: Status placed in the update's ``task_status`` field.
        error_message: Optional error text carried by the update.
        result_key: Optional result key carried by the update.

    Returns:
        A dict with ``"type": "message"`` and ``"data"`` set to the JSON
        produced by ``_TaskUpdate.model_dump_json()``.
    """
    update = _TaskUpdate(
        task_id=task_id,
        task_status=status,
        error_message=error_message,
        result_key=result_key,
    )
    return {"type": "message", "data": update.model_dump_json()}
30+
31+
32+
async def _fake_listen(messages):
33+
for msg in messages:
34+
yield msg
35+
36+
37+
def _make_pubsub(messages):
    """Return a mock pub/sub object that replays ``messages`` from ``listen()``.

    ``subscribe`` is an ``AsyncMock`` so it can be awaited. ``listen()``
    returns a single async generator created here, so the messages can be
    consumed only once per mock.
    """
    pubsub = MagicMock()
    pubsub.subscribe = AsyncMock()
    pubsub.listen.return_value = _fake_listen(messages)
    return pubsub
42+
43+
44+
def _make_orchestrator_with_task():
    """Create a bare ``RQOrchestrator`` with one ``Task`` registered on it.

    The orchestrator is allocated without running ``__init__``; only the
    attributes assigned below exist on it. The Redis connection is replaced
    by a ``MagicMock`` so tests can inject a fake pub/sub object.

    Returns:
        A ``(orchestrator, task)`` tuple where ``task`` is stored in
        ``orchestrator.tasks`` under its ``task_id``.
    """
    # object.__new__ allocates the instance without invoking __init__, so
    # the constructor does not need to be patched out.
    orch = object.__new__(RQOrchestrator)
    orch.config = RQOrchestratorConfig()
    orch.tasks = {}
    orch.notifier = None
    orch._task_result_keys = {}
    orch._async_redis_conn = MagicMock()

    task = Task(
        task_id="test-task-1",
        sources=[],
        target=InBodyTarget(),
    )
    orch.tasks[task.task_id] = task
    return orch, task
61+
62+
63+
class TestTaskUpdateErrorMessage:
    """Checks for the ``error_message`` field on ``_TaskUpdate``."""

    def test_task_update_with_error_message(self):
        """An explicitly supplied error message is stored on the update."""
        update = _TaskUpdate(
            task_id="t1",
            task_status=TaskStatus.FAILURE,
            error_message="conversion failed: corrupt PDF",
        )
        assert update.error_message == "conversion failed: corrupt PDF"

    def test_task_update_without_error_message(self):
        """The field is ``None`` when no error message is given."""
        update = _TaskUpdate(
            task_id="t1",
            task_status=TaskStatus.SUCCESS,
        )
        assert update.error_message is None

    def test_task_update_serialization_roundtrip(self):
        """The error message and status survive a JSON dump/validate cycle."""
        update = _TaskUpdate(
            task_id="t1",
            task_status=TaskStatus.FAILURE,
            error_message="OOM killed",
        )
        json_str = update.model_dump_json()
        restored = _TaskUpdate.model_validate_json(json_str)
        assert restored.error_message == "OOM killed"
        assert restored.task_status == TaskStatus.FAILURE

    def test_task_update_backward_compatible(self):
        """JSON without an ``error_message`` key still validates."""
        json_without_error = '{"task_id": "t1", "task_status": "failure"}'
        update = _TaskUpdate.model_validate_json(json_without_error)
        assert update.error_message is None
        assert update.task_status == TaskStatus.FAILURE
95+
96+
97+
class TestTaskErrorMessage:
    """Checks for the ``error_message`` field on ``Task``."""

    def test_task_has_error_message_field(self):
        """A task built without an error message reports ``None``."""
        task = Task(task_id="t1", sources=[], target=InBodyTarget())
        assert task.error_message is None

    def test_task_with_error_message(self):
        """An error message passed to the constructor is stored as given."""
        task = Task(
            task_id="t1",
            sources=[],
            target=InBodyTarget(),
            error_message="something broke",
        )
        assert task.error_message == "something broke"

    def test_task_error_message_serialization(self):
        """The error message appears in the JSON-mode model dump."""
        task = Task(
            task_id="t1",
            sources=[],
            target=InBodyTarget(),
            error_message="timeout after 300s",
        )
        data = task.model_dump(mode="json", serialize_as_any=True)
        assert data["error_message"] == "timeout after 300s"
120+
121+
122+
class TestListenForUpdatesErrorPropagation:
    """Drive ``_listen_for_updates`` with fake pub/sub messages.

    Each test registers one task, feeds a finite message list through a
    mocked pub/sub object, and then inspects the task's status and
    ``error_message``.
    """

    @pytest.mark.asyncio
    async def test_failure_with_error_message_sets_task_error(self):
        """A FAILURE update carrying an error message sets it on the task."""
        orch, task = _make_orchestrator_with_task()

        messages = [
            _make_pubsub_message(
                task.task_id,
                TaskStatus.FAILURE,
                error_message="RuntimeError: No converter",
            )
        ]
        orch._async_redis_conn.pubsub.return_value = _make_pubsub(messages)

        await orch._listen_for_updates()

        assert task.task_status == TaskStatus.FAILURE
        assert task.error_message == "RuntimeError: No converter"

    @pytest.mark.asyncio
    async def test_failure_without_error_message_leaves_none(self):
        """A FAILURE update without an error message leaves the field ``None``."""
        orch, task = _make_orchestrator_with_task()

        messages = [_make_pubsub_message(task.task_id, TaskStatus.FAILURE)]
        orch._async_redis_conn.pubsub.return_value = _make_pubsub(messages)

        await orch._listen_for_updates()

        assert task.task_status == TaskStatus.FAILURE
        assert task.error_message is None

    @pytest.mark.asyncio
    async def test_success_does_not_set_error_message(self):
        """A SUCCESS update leaves ``error_message`` as ``None``."""
        orch, task = _make_orchestrator_with_task()

        messages = [
            _make_pubsub_message(
                task.task_id,
                TaskStatus.SUCCESS,
                result_key="docling:results:test-task-1",
            )
        ]
        orch._async_redis_conn.pubsub.return_value = _make_pubsub(messages)

        await orch._listen_for_updates()

        assert task.task_status == TaskStatus.SUCCESS
        assert task.error_message is None

    @pytest.mark.asyncio
    async def test_started_then_failure_preserves_error(self):
        """A STARTED update followed by FAILURE ends with the error stored."""
        orch, task = _make_orchestrator_with_task()

        messages = [
            _make_pubsub_message(task.task_id, TaskStatus.STARTED),
            _make_pubsub_message(
                task.task_id,
                TaskStatus.FAILURE,
                error_message="GPU OOM",
            ),
        ]
        orch._async_redis_conn.pubsub.return_value = _make_pubsub(messages)

        await orch._listen_for_updates()

        assert task.task_status == TaskStatus.FAILURE
        assert task.error_message == "GPU OOM"

0 commit comments

Comments
 (0)