Skip to content

Commit fdd1e65

Browse files
authored
Allow Resuming and Forking to a Queue (#612)
Closes #504
1 parent 1222f0c commit fdd1e65

File tree

7 files changed

+186
-20
lines changed

7 files changed

+186
-20
lines changed

dbos/_client.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,27 +404,57 @@ async def delete_workflows_async(
404404
self.delete_workflows, workflow_ids, delete_children=delete_children
405405
)
406406

407-
def resume_workflow(self, workflow_id: str) -> "WorkflowHandle[Any]":
408-
self._sys_db.resume_workflows([workflow_id])
407+
def resume_workflow(
408+
self,
409+
workflow_id: str,
410+
*,
411+
queue_name: Optional[str] = None,
412+
) -> "WorkflowHandle[Any]":
413+
self._sys_db.resume_workflows(
414+
[workflow_id],
415+
queue_name=queue_name,
416+
)
409417
return WorkflowHandleClientPolling[Any](workflow_id, self._sys_db)
410418

411419
async def resume_workflow_async(
412-
self, workflow_id: str
420+
self,
421+
workflow_id: str,
422+
*,
423+
queue_name: Optional[str] = None,
413424
) -> "WorkflowHandleAsync[Any]":
414-
await asyncio.to_thread(self.resume_workflow, workflow_id)
425+
await asyncio.to_thread(
426+
self.resume_workflow,
427+
workflow_id,
428+
queue_name=queue_name,
429+
)
415430
return WorkflowHandleClientAsyncPolling[Any](workflow_id, self._sys_db)
416431

417-
def resume_workflows(self, workflow_ids: List[str]) -> "List[WorkflowHandle[Any]]":
418-
self._sys_db.resume_workflows(workflow_ids)
432+
def resume_workflows(
433+
self,
434+
workflow_ids: List[str],
435+
*,
436+
queue_name: Optional[str] = None,
437+
) -> "List[WorkflowHandle[Any]]":
438+
self._sys_db.resume_workflows(
439+
workflow_ids,
440+
queue_name=queue_name,
441+
)
419442
return [
420443
WorkflowHandleClientPolling[Any](wfid, self._sys_db)
421444
for wfid in workflow_ids
422445
]
423446

424447
async def resume_workflows_async(
425-
self, workflow_ids: List[str]
448+
self,
449+
workflow_ids: List[str],
450+
*,
451+
queue_name: Optional[str] = None,
426452
) -> "List[WorkflowHandleAsync[Any]]":
427-
await asyncio.to_thread(self._sys_db.resume_workflows, workflow_ids)
453+
await asyncio.to_thread(
454+
self._sys_db.resume_workflows,
455+
workflow_ids,
456+
queue_name=queue_name,
457+
)
428458
return [
429459
WorkflowHandleClientAsyncPolling[Any](wfid, self._sys_db)
430460
for wfid in workflow_ids
@@ -613,12 +643,16 @@ def fork_workflow(
613643
start_step: int,
614644
*,
615645
application_version: Optional[str] = None,
646+
queue_name: Optional[str] = None,
647+
queue_partition_key: Optional[str] = None,
616648
) -> "WorkflowHandle[Any]":
617649
forked_workflow_id = fork_workflow(
618650
self._sys_db,
619651
workflow_id,
620652
start_step,
621653
application_version=application_version,
654+
queue_name=queue_name,
655+
queue_partition_key=queue_partition_key,
622656
)
623657
return WorkflowHandleClientPolling[Any](forked_workflow_id, self._sys_db)
624658

@@ -628,13 +662,17 @@ async def fork_workflow_async(
628662
start_step: int,
629663
*,
630664
application_version: Optional[str] = None,
665+
queue_name: Optional[str] = None,
666+
queue_partition_key: Optional[str] = None,
631667
) -> "WorkflowHandleAsync[Any]":
632668
forked_workflow_id = await asyncio.to_thread(
633669
fork_workflow,
634670
self._sys_db,
635671
workflow_id,
636672
start_step,
637673
application_version=application_version,
674+
queue_name=queue_name,
675+
queue_partition_key=queue_partition_key,
638676
)
639677
return WorkflowHandleClientAsyncPolling[Any](forked_workflow_id, self._sys_db)
640678

dbos/_conductor/conductor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,10 @@ def run(self) -> None:
193193
]
194194
success = True
195195
try:
196-
self.dbos.resume_workflows(resume_ids)
196+
self.dbos.resume_workflows(
197+
resume_ids,
198+
queue_name=resume_message.queue_name,
199+
)
197200
except Exception as e:
198201
error_message = f"Exception encountered when resuming workflow(s) {resume_ids}: {traceback.format_exc()}"
199202
self.dbos.logger.error(error_message)
@@ -230,12 +233,18 @@ def run(self) -> None:
230233
workflow_id = fork_message.body["workflow_id"]
231234
start_step = fork_message.body["start_step"]
232235
app_version = fork_message.body["application_version"]
236+
queue_name = fork_message.body.get("queue_name")
237+
queue_partition_key = fork_message.body.get(
238+
"queue_partition_key"
239+
)
233240
try:
234241
with SetWorkflowID(new_workflow_id):
235242
new_handle = self.dbos.fork_workflow(
236243
workflow_id,
237244
start_step,
238245
application_version=app_version,
246+
queue_name=queue_name,
247+
queue_partition_key=queue_partition_key,
239248
)
240249
new_workflow_id = new_handle.workflow_id
241250
except Exception as e:

dbos/_conductor/protocol.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class DeleteResponse(BaseMessage):
131131
class ResumeRequest(BaseMessage):
132132
workflow_id: str
133133
workflow_ids: Optional[List[str]] = None
134+
queue_name: Optional[str] = None
134135

135136

136137
@dataclass
@@ -370,6 +371,8 @@ class ForkWorkflowBody(TypedDict):
370371
start_step: int
371372
application_version: Optional[str]
372373
new_workflow_id: Optional[str]
374+
queue_name: Optional[str]
375+
queue_partition_key: Optional[str]
373376

374377

375378
@dataclass

dbos/_dbos.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,28 +1654,44 @@ async def _configure_asyncio_thread_pool(cls) -> None:
16541654
loop.set_default_executor(_get_dbos_instance()._executor)
16551655

16561656
@classmethod
1657-
def resume_workflow(cls, workflow_id: str) -> WorkflowHandle[Any]:
1657+
def resume_workflow(
1658+
cls,
1659+
workflow_id: str,
1660+
*,
1661+
queue_name: Optional[str] = None,
1662+
) -> WorkflowHandle[Any]:
16581663
"""Resume a workflow by ID."""
16591664
check_async("resume_workflow")
16601665

16611666
def fn() -> None:
16621667
dbos_logger.info(f"Resuming workflow: {workflow_id}")
1663-
_get_dbos_instance()._sys_db.resume_workflows([workflow_id])
1668+
_get_dbos_instance()._sys_db.resume_workflows(
1669+
[workflow_id],
1670+
queue_name=queue_name,
1671+
)
16641672

16651673
_get_dbos_instance()._sys_db.call_function_as_step(
16661674
fn, "DBOS.resumeWorkflow", snapshot_step_context(reserve_sleep_id=False)
16671675
)
16681676
return WorkflowHandlePolling(workflow_id, _get_dbos_instance())
16691677

16701678
@classmethod
1671-
async def resume_workflow_async(cls, workflow_id: str) -> WorkflowHandleAsync[Any]:
1679+
async def resume_workflow_async(
1680+
cls,
1681+
workflow_id: str,
1682+
*,
1683+
queue_name: Optional[str] = None,
1684+
) -> WorkflowHandleAsync[Any]:
16721685
"""Resume a workflow by ID."""
16731686
step_ctx_res = snapshot_step_context(reserve_sleep_id=False)
16741687
await cls._configure_asyncio_thread_pool()
16751688

16761689
def fnres() -> None:
16771690
dbos_logger.info(f"Resuming workflow: {workflow_id}")
1678-
_get_dbos_instance()._sys_db.resume_workflows([workflow_id])
1691+
_get_dbos_instance()._sys_db.resume_workflows(
1692+
[workflow_id],
1693+
queue_name=queue_name,
1694+
)
16791695

16801696
await asyncio.to_thread(
16811697
_get_dbos_instance()._sys_db.call_function_as_step,
@@ -1686,13 +1702,21 @@ def fnres() -> None:
16861702
return WorkflowHandleAsyncPolling(workflow_id, _get_dbos_instance())
16871703

16881704
@classmethod
1689-
def resume_workflows(cls, workflow_ids: List[str]) -> List[WorkflowHandle[Any]]:
1705+
def resume_workflows(
1706+
cls,
1707+
workflow_ids: List[str],
1708+
*,
1709+
queue_name: Optional[str] = None,
1710+
) -> List[WorkflowHandle[Any]]:
16901711
"""Resume multiple workflows by ID."""
16911712
check_async("resume_workflows")
16921713

16931714
def fn() -> None:
16941715
dbos_logger.info(f"Resuming workflows: {workflow_ids}")
1695-
_get_dbos_instance()._sys_db.resume_workflows(workflow_ids)
1716+
_get_dbos_instance()._sys_db.resume_workflows(
1717+
workflow_ids,
1718+
queue_name=queue_name,
1719+
)
16961720

16971721
_get_dbos_instance()._sys_db.call_function_as_step(
16981722
fn, "DBOS.resumeWorkflow", snapshot_step_context(reserve_sleep_id=False)
@@ -1703,15 +1727,21 @@ def fn() -> None:
17031727

17041728
@classmethod
17051729
async def resume_workflows_async(
1706-
cls, workflow_ids: List[str]
1730+
cls,
1731+
workflow_ids: List[str],
1732+
*,
1733+
queue_name: Optional[str] = None,
17071734
) -> List[WorkflowHandleAsync[Any]]:
17081735
"""Resume multiple workflows by ID."""
17091736
step_ctx_res = snapshot_step_context(reserve_sleep_id=False)
17101737
await cls._configure_asyncio_thread_pool()
17111738

17121739
def fnres() -> None:
17131740
dbos_logger.info(f"Resuming workflows: {workflow_ids}")
1714-
_get_dbos_instance()._sys_db.resume_workflows(workflow_ids)
1741+
_get_dbos_instance()._sys_db.resume_workflows(
1742+
workflow_ids,
1743+
queue_name=queue_name,
1744+
)
17151745

17161746
await asyncio.to_thread(
17171747
_get_dbos_instance()._sys_db.call_function_as_step,
@@ -1732,6 +1762,8 @@ def fork_workflow(
17321762
start_step: int,
17331763
*,
17341764
application_version: Optional[str] = None,
1765+
queue_name: Optional[str] = None,
1766+
queue_partition_key: Optional[str] = None,
17351767
) -> WorkflowHandle[Any]:
17361768
"""Restart a workflow with a new workflow ID from a specific step"""
17371769
check_async("fork_workflow")
@@ -1743,6 +1775,8 @@ def fn() -> str:
17431775
workflow_id,
17441776
start_step,
17451777
application_version=application_version,
1778+
queue_name=queue_name,
1779+
queue_partition_key=queue_partition_key,
17461780
)
17471781

17481782
new_id = _get_dbos_instance()._sys_db.call_function_as_step(
@@ -1757,6 +1791,8 @@ async def fork_workflow_async(
17571791
start_step: int,
17581792
*,
17591793
application_version: Optional[str] = None,
1794+
queue_name: Optional[str] = None,
1795+
queue_partition_key: Optional[str] = None,
17601796
) -> WorkflowHandleAsync[Any]:
17611797
"""Restart a workflow with a new workflow ID from a specific step"""
17621798
step_ctx_res = snapshot_step_context(reserve_sleep_id=False)
@@ -1770,6 +1806,8 @@ def fn() -> str:
17701806
workflow_id,
17711807
start_step,
17721808
application_version=application_version,
1809+
queue_name=queue_name,
1810+
queue_partition_key=queue_partition_key,
17731811
)
17741812

17751813
new_id = await asyncio.to_thread(

dbos/_sys_db.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,12 @@ def cancel_workflows(
726726
)
727727
)
728728

729-
def resume_workflows(self, workflow_ids: list[str]) -> None:
729+
def resume_workflows(
730+
self,
731+
workflow_ids: list[str],
732+
*,
733+
queue_name: Optional[str] = None,
734+
) -> None:
730735
with self.engine.begin() as c:
731736
# Set the workflows' status to ENQUEUED and clear recovery attempts and deadline,
732737
# but only if the workflow is not already complete.
@@ -743,7 +748,9 @@ def resume_workflows(self, workflow_ids: list[str]) -> None:
743748
)
744749
.values(
745750
status=WorkflowStatusString.ENQUEUED.value,
746-
queue_name=INTERNAL_QUEUE_NAME,
751+
queue_name=(
752+
queue_name if queue_name is not None else INTERNAL_QUEUE_NAME
753+
),
747754
recovery_attempts=0,
748755
workflow_deadline_epoch_ms=None,
749756
deduplication_id=None,
@@ -768,6 +775,8 @@ def fork_workflow(
768775
start_step: int,
769776
*,
770777
application_version: Optional[str],
778+
queue_name: Optional[str] = None,
779+
queue_partition_key: Optional[str] = None,
771780
) -> str:
772781

773782
status = self.get_workflow_status(original_workflow_id)
@@ -789,7 +798,10 @@ def fork_workflow(
789798
authenticated_user=status["authenticated_user"],
790799
authenticated_roles=status["authenticated_roles"],
791800
serialization=status["serialization"],
792-
queue_name=INTERNAL_QUEUE_NAME,
801+
queue_name=(
802+
queue_name if queue_name is not None else INTERNAL_QUEUE_NAME
803+
),
804+
queue_partition_key=queue_partition_key,
793805
inputs=status["inputs"],
794806
assumed_role=status["assumed_role"],
795807
forked_from=original_workflow_id,

dbos/_workflow_commands.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def fork_workflow(
2323
start_step: int,
2424
*,
2525
application_version: Optional[str],
26+
queue_name: Optional[str] = None,
27+
queue_partition_key: Optional[str] = None,
2628
) -> str:
2729

2830
ctx = get_local_dbos_context()
@@ -36,6 +38,8 @@ def fork_workflow(
3638
forked_workflow_id,
3739
start_step,
3840
application_version=application_version,
41+
queue_name=queue_name,
42+
queue_partition_key=queue_partition_key,
3943
)
4044
return forked_workflow_id
4145

0 commit comments

Comments
 (0)