Skip to content

Commit 1553d98

Browse files
maxdml and kraftp authored
Fix queue recovery (#204)
When performing recovery, we now re-enqueue workflows that came from a queue. This allows tasks from a queue to respect the concurrency limits. Re-enqueue = reset the start time and executor assignment in the queue table. This ensures the task is re-inserted in the same position in the queue. --------- Co-authored-by: Peter Kraft <[email protected]>
1 parent 64dde17 commit 1553d98

File tree

5 files changed

+179
-21
lines changed

5 files changed

+179
-21
lines changed

dbos/_dbos.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
)
5757
from ._roles import default_required_roles, required_roles
5858
from ._scheduler import ScheduledWorkflow, scheduled
59-
from ._sys_db import WorkflowStatusString, reset_system_database
59+
from ._sys_db import reset_system_database
6060
from ._tracer import dbos_tracer
6161

6262
if TYPE_CHECKING:
@@ -613,6 +613,7 @@ def get_workflow_status(cls, workflow_id: str) -> Optional[WorkflowStatus]:
613613
workflow_id=workflow_id,
614614
status=stat["status"],
615615
name=stat["name"],
616+
executor_id=stat["executor_id"],
616617
recovery_attempts=stat["recovery_attempts"],
617618
class_name=stat["class_name"],
618619
config_name=stat["config_name"],
@@ -909,6 +910,7 @@ class WorkflowStatus:
909910
workflow_id(str): The ID of the workflow execution
910911
status(str): The status of the execution, from `WorkflowStatusString`
911912
name(str): The workflow function name
913+
executor_id(str): The ID of the executor running the workflow
912914
class_name(str): For member functions, the name of the class containing the workflow function
913915
config_name(str): For instance member functions, the name of the class instance for the execution
914916
queue_name(str): For workflows that are or were queued, the queue name
@@ -922,6 +924,7 @@ class WorkflowStatus:
922924
workflow_id: str
923925
status: str
924926
name: str
927+
executor_id: Optional[str]
925928
class_name: Optional[str]
926929
config_name: Optional[str]
927930
queue_name: Optional[str]

dbos/_recovery.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,29 @@
66

77
from ._core import execute_workflow_by_id
88
from ._error import DBOSWorkflowFunctionNotFoundError
9+
from ._sys_db import GetPendingWorkflowsOutput
910

1011
if TYPE_CHECKING:
1112
from ._dbos import DBOS, WorkflowHandle
1213

1314

14-
def startup_recovery_thread(dbos: "DBOS", workflow_ids: List[str]) -> None:
15+
def startup_recovery_thread(
16+
dbos: "DBOS", pending_workflows: List[GetPendingWorkflowsOutput]
17+
) -> None:
1518
"""Attempt to recover local pending workflows on startup using a background thread."""
1619
stop_event = threading.Event()
1720
dbos.stop_events.append(stop_event)
18-
while not stop_event.is_set() and len(workflow_ids) > 0:
21+
while not stop_event.is_set() and len(pending_workflows) > 0:
1922
try:
20-
for workflowID in list(workflow_ids):
21-
execute_workflow_by_id(dbos, workflowID)
22-
workflow_ids.remove(workflowID)
23+
for pending_workflow in list(pending_workflows):
24+
if (
25+
pending_workflow.queue_name
26+
and pending_workflow.queue_name != "_dbos_internal_queue"
27+
):
28+
dbos._sys_db.clear_queue_assignment(pending_workflow.workflow_uuid)
29+
continue
30+
execute_workflow_by_id(dbos, pending_workflow.workflow_uuid)
31+
pending_workflows.remove(pending_workflow)
2332
except DBOSWorkflowFunctionNotFoundError:
2433
time.sleep(1)
2534
except Exception as e:
@@ -39,12 +48,23 @@ def recover_pending_workflows(
3948
f"Skip local recovery because it's running in a VM: {os.environ.get('DBOS__VMID')}"
4049
)
4150
dbos.logger.debug(f"Recovering pending workflows for executor: {executor_id}")
42-
workflow_ids = dbos._sys_db.get_pending_workflows(executor_id)
43-
dbos.logger.debug(f"Pending workflows: {workflow_ids}")
44-
45-
for workflowID in workflow_ids:
46-
handle = execute_workflow_by_id(dbos, workflowID)
47-
workflow_handles.append(handle)
51+
pending_workflows = dbos._sys_db.get_pending_workflows(executor_id)
52+
for pending_workflow in pending_workflows:
53+
if (
54+
pending_workflow.queue_name
55+
and pending_workflow.queue_name != "_dbos_internal_queue"
56+
):
57+
try:
58+
dbos._sys_db.clear_queue_assignment(pending_workflow.workflow_uuid)
59+
workflow_handles.append(
60+
dbos.retrieve_workflow(pending_workflow.workflow_uuid)
61+
)
62+
except Exception as e:
63+
dbos.logger.error(e)
64+
else:
65+
workflow_handles.append(
66+
execute_workflow_by_id(dbos, pending_workflow.workflow_uuid)
67+
)
4868

4969
dbos.logger.info("Recovered pending workflows")
5070
return workflow_handles

dbos/_sys_db.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ def __init__(self, workflow_uuids: List[str]):
140140
self.workflow_uuids = workflow_uuids
141141

142142

143+
class GetPendingWorkflowsOutput:
144+
def __init__(self, *, workflow_uuid: str, queue_name: Optional[str] = None):
145+
self.workflow_uuid: str = workflow_uuid
146+
self.queue_name: Optional[str] = queue_name
147+
148+
143149
class WorkflowInformation(TypedDict, total=False):
144150
workflow_uuid: str
145151
status: WorkflowStatuses # The status of the workflow.
@@ -465,6 +471,7 @@ def get_workflow_status(
465471
SystemSchema.workflow_status.c.authenticated_roles,
466472
SystemSchema.workflow_status.c.assumed_role,
467473
SystemSchema.workflow_status.c.queue_name,
474+
SystemSchema.workflow_status.c.executor_id,
468475
).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
469476
).fetchone()
470477
if row is None:
@@ -479,7 +486,7 @@ def get_workflow_status(
479486
"error": None,
480487
"app_id": None,
481488
"app_version": None,
482-
"executor_id": None,
489+
"executor_id": row[10],
483490
"request": row[2],
484491
"recovery_attempts": row[3],
485492
"authenticated_user": row[6],
@@ -746,16 +753,27 @@ def get_queued_workflows(
746753

747754
return GetWorkflowsOutput(workflow_uuids)
748755

749-
def get_pending_workflows(self, executor_id: str) -> list[str]:
756+
def get_pending_workflows(
757+
self, executor_id: str
758+
) -> list[GetPendingWorkflowsOutput]:
750759
with self.engine.begin() as c:
751760
rows = c.execute(
752-
sa.select(SystemSchema.workflow_status.c.workflow_uuid).where(
761+
sa.select(
762+
SystemSchema.workflow_status.c.workflow_uuid,
763+
SystemSchema.workflow_status.c.queue_name,
764+
).where(
753765
SystemSchema.workflow_status.c.status
754766
== WorkflowStatusString.PENDING.value,
755767
SystemSchema.workflow_status.c.executor_id == executor_id,
756768
)
757769
).fetchall()
758-
return [row[0] for row in rows]
770+
return [
771+
GetPendingWorkflowsOutput(
772+
workflow_uuid=row.workflow_uuid,
773+
queue_name=row.queue_name,
774+
)
775+
for row in rows
776+
]
759777

760778
def record_operation_result(
761779
self, result: OperationResultInternal, conn: Optional[sa.Connection] = None
@@ -1375,6 +1393,19 @@ def remove_from_queue(self, workflow_id: str, queue: "Queue") -> None:
13751393
.values(completed_at_epoch_ms=int(time.time() * 1000))
13761394
)
13771395

1396+
def clear_queue_assignment(self, workflow_id: str) -> None:
1397+
with self.engine.begin() as c:
1398+
c.execute(
1399+
sa.update(SystemSchema.workflow_queue)
1400+
.where(SystemSchema.workflow_queue.c.workflow_uuid == workflow_id)
1401+
.values(executor_id=None, started_at_epoch_ms=None)
1402+
)
1403+
c.execute(
1404+
sa.update(SystemSchema.workflow_status)
1405+
.where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
1406+
.values(executor_id=None, status=WorkflowStatusString.ENQUEUED.value)
1407+
)
1408+
13781409

13791410
def reset_system_database(config: ConfigFile) -> None:
13801411
sysdb_name = (

tests/test_failures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sqlalchemy.exc import InvalidRequestError, OperationalError
1010

1111
# Public API
12-
from dbos import DBOS, GetWorkflowsInput, Queue, SetWorkflowID
12+
from dbos import DBOS, GetWorkflowsInput, SetWorkflowID
1313
from dbos._error import DBOSDeadLetterQueueError, DBOSException
1414
from dbos._sys_db import WorkflowStatusString
1515

tests/test_queue.py

Lines changed: 108 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -615,11 +615,17 @@ def test_step(i: int) -> int:
615615
original_handle = DBOS.start_workflow(test_workflow)
616616
for e in step_events:
617617
e.wait()
618+
e.clear()
619+
618620
assert step_counter == 5
619621

620622
# Recover the workflow, then resume it.
621623
recovery_handles = DBOS.recover_pending_workflows()
624+
# Wait until the 2nd invocation of each workflow is dequeued and executed
625+
for e in step_events:
626+
e.wait()
622627
event.set()
628+
623629
# There should be one handle for the workflow and another for each queued step.
624630
assert len(recovery_handles) == queued_steps + 1
625631
# Verify that both the recovered and original workflows complete correctly.
@@ -639,6 +645,84 @@ def test_step(i: int) -> int:
639645
assert queue_entries_are_cleaned_up(dbos)
640646

641647

648+
def test_queue_concurrency_under_recovery(dbos: DBOS) -> None:
649+
event = threading.Event()
650+
wf_events = [threading.Event() for _ in range(2)]
651+
counter = 0
652+
653+
@DBOS.workflow()
654+
def blocked_workflow(i: int) -> None:
655+
wf_events[i].set()
656+
nonlocal counter
657+
counter += 1
658+
event.wait()
659+
660+
@DBOS.workflow()
661+
def noop() -> None:
662+
pass
663+
664+
queue = Queue("test_queue", concurrency=2)
665+
handle1 = queue.enqueue(blocked_workflow, 0)
666+
handle2 = queue.enqueue(blocked_workflow, 1)
667+
handle3 = queue.enqueue(noop)
668+
669+
# Wait for the first two workflows to be dequeued
670+
for e in wf_events:
671+
e.wait()
672+
e.clear()
673+
674+
assert counter == 2
675+
assert handle1.get_status().status == WorkflowStatusString.PENDING.value
676+
assert handle2.get_status().status == WorkflowStatusString.PENDING.value
677+
assert handle3.get_status().status == WorkflowStatusString.ENQUEUED.value
678+
679+
# Manually update the database to pretend the 3rd workflow is PENDING and comes from another executor
680+
with dbos._sys_db.engine.begin() as c:
681+
query = (
682+
sa.update(SystemSchema.workflow_status)
683+
.values(status=WorkflowStatusString.PENDING.value, executor_id="other")
684+
.where(
685+
SystemSchema.workflow_status.c.workflow_uuid
686+
== handle3.get_workflow_id()
687+
)
688+
)
689+
c.execute(query)
690+
691+
# Trigger workflow recovery for "other". The first two workflows should still be blocked, and the 3rd one should be enqueued
692+
recovered_other_handles = DBOS.recover_pending_workflows(["other"])
693+
assert handle1.get_status().status == WorkflowStatusString.PENDING.value
694+
assert handle2.get_status().status == WorkflowStatusString.PENDING.value
695+
assert len(recovered_other_handles) == 1
696+
assert recovered_other_handles[0].get_workflow_id() == handle3.get_workflow_id()
697+
assert handle3.get_status().status == WorkflowStatusString.ENQUEUED.value
698+
699+
# Trigger workflow recovery for "local". The first two workflows should be re-enqueued, then dequeued again
700+
recovered_local_handles = DBOS.recover_pending_workflows(["local"])
701+
assert len(recovered_local_handles) == 2
702+
for h in recovered_local_handles:
703+
assert h.get_workflow_id() in [
704+
handle1.get_workflow_id(),
705+
handle2.get_workflow_id(),
706+
]
707+
for e in wf_events:
708+
e.wait()
709+
assert counter == 4
710+
assert handle1.get_status().status == WorkflowStatusString.PENDING.value
711+
assert handle2.get_status().status == WorkflowStatusString.PENDING.value
712+
# Because tasks are re-enqueued in order, the 3rd task is head of line blocked
713+
assert handle3.get_status().status == WorkflowStatusString.ENQUEUED.value
714+
715+
# Unblock the first two workflows
716+
event.set()
717+
718+
# Verify all queue entries eventually get cleaned up.
719+
assert handle1.get_result() == None
720+
assert handle2.get_result() == None
721+
assert handle3.get_result() == None
722+
assert handle3.get_status().executor_id == "local"
723+
assert queue_entries_are_cleaned_up(dbos)
724+
725+
642726
def test_cancelling_queued_workflows(dbos: DBOS) -> None:
643727
start_event = threading.Event()
644728
blocking_event = threading.Event()
@@ -746,17 +830,28 @@ def regular_workflow() -> None:
746830

747831
# Attempt to recover the blocked workflow the maximum number of times
748832
for i in range(max_recovery_attempts):
833+
start_event.clear()
749834
DBOS.recover_pending_workflows()
835+
start_event.wait()
750836
assert recovery_count == i + 2
751837

752-
# Verify an additional recovery throws a DLQ error and puts the workflow in the DLQ status.
753-
with pytest.raises(Exception) as exc_info:
754-
DBOS.recover_pending_workflows()
755-
assert exc_info.errisinstance(DBOSDeadLetterQueueError)
838+
# Verify that an additional recovery puts the workflow in the DLQ status.
839+
DBOS.recover_pending_workflows()
840+
# we can't start_event.wait() here because the workflow will never execute
841+
time.sleep(2)
756842
assert (
757843
blocked_handle.get_status().status
758844
== WorkflowStatusString.RETRIES_EXCEEDED.value
759845
)
846+
with dbos._sys_db.engine.begin() as c:
847+
query = sa.select(SystemSchema.workflow_status.c.recovery_attempts).where(
848+
SystemSchema.workflow_status.c.workflow_uuid
849+
== blocked_handle.get_workflow_id()
850+
)
851+
result = c.execute(query)
852+
row = result.fetchone()
853+
assert row is not None
854+
assert row[0] == max_recovery_attempts + 2
760855

761856
# Verify the blocked workflow entering the DLQ lets the regular workflow run
762857
assert regular_handle.get_result() == None
@@ -766,6 +861,15 @@ def regular_workflow() -> None:
766861
assert blocked_handle.get_result() == None
767862
dbos._sys_db.wait_for_buffer_flush()
768863
assert blocked_handle.get_status().status == WorkflowStatusString.SUCCESS.value
864+
with dbos._sys_db.engine.begin() as c:
865+
query = sa.select(SystemSchema.workflow_status.c.recovery_attempts).where(
866+
SystemSchema.workflow_status.c.workflow_uuid
867+
== blocked_handle.get_workflow_id()
868+
)
869+
result = c.execute(query)
870+
row = result.fetchone()
871+
assert row is not None
872+
assert row[0] == max_recovery_attempts + 2
769873

770874
# Verify all queue entries eventually get cleaned up.
771875
assert queue_entries_are_cleaned_up(dbos)

0 commit comments

Comments
 (0)