diff --git a/dbos/_client.py b/dbos/_client.py index c0e2c084..848ed2bb 100644 --- a/dbos/_client.py +++ b/dbos/_client.py @@ -364,13 +364,47 @@ async def get_event_async( ) def cancel_workflow(self, workflow_id: str) -> None: - self._sys_db.cancel_workflow(workflow_id) + self._sys_db.cancel_workflows([workflow_id]) async def cancel_workflow_async(self, workflow_id: str) -> None: await asyncio.to_thread(self.cancel_workflow, workflow_id) + def cancel_workflows(self, workflow_ids: List[str]) -> None: + self._sys_db.cancel_workflows(workflow_ids) + + async def cancel_workflows_async(self, workflow_ids: List[str]) -> None: + await asyncio.to_thread(self._sys_db.cancel_workflows, workflow_ids) + + def delete_workflow( + self, workflow_id: str, *, delete_children: bool = False + ) -> None: + self.delete_workflows([workflow_id], delete_children=delete_children) + + async def delete_workflow_async( + self, workflow_id: str, *, delete_children: bool = False + ) -> None: + await asyncio.to_thread( + self.delete_workflows, [workflow_id], delete_children=delete_children + ) + + def delete_workflows( + self, workflow_ids: List[str], *, delete_children: bool = False + ) -> None: + all_ids = list(workflow_ids) + if delete_children: + for wfid in workflow_ids: + all_ids.extend(self._sys_db.get_workflow_children(wfid)) + self._sys_db.delete_workflows(all_ids) + + async def delete_workflows_async( + self, workflow_ids: List[str], *, delete_children: bool = False + ) -> None: + await asyncio.to_thread( + self.delete_workflows, workflow_ids, delete_children=delete_children + ) + def resume_workflow(self, workflow_id: str) -> "WorkflowHandle[Any]": - self._sys_db.resume_workflow(workflow_id) + self._sys_db.resume_workflows([workflow_id]) return WorkflowHandleClientPolling[Any](workflow_id, self._sys_db) async def resume_workflow_async( @@ -379,6 +413,22 @@ async def resume_workflow_async( await asyncio.to_thread(self.resume_workflow, workflow_id) return WorkflowHandleClientAsyncPolling[Any](workflow_id, self._sys_db) + def resume_workflows(self, workflow_ids: List[str]) -> "List[WorkflowHandle[Any]]": + self._sys_db.resume_workflows(workflow_ids) + return [ + WorkflowHandleClientPolling[Any](wfid, self._sys_db) + for wfid in workflow_ids + ] + + async def resume_workflows_async( + self, workflow_ids: List[str] + ) -> "List[WorkflowHandleAsync[Any]]": + await asyncio.to_thread(self._sys_db.resume_workflows, workflow_ids) + return [ + WorkflowHandleClientAsyncPolling[Any](wfid, self._sys_db) + for wfid in workflow_ids + ] + def list_workflows( self, *, diff --git a/dbos/_conductor/conductor.py b/dbos/_conductor/conductor.py index d785b538..f83119cb 100644 --- a/dbos/_conductor/conductor.py +++ b/dbos/_conductor/conductor.py @@ -7,7 +7,7 @@ import traceback from datetime import datetime from importlib.metadata import version -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional from websockets import ConnectionClosed, ConnectionClosedOK, InvalidStatus from websockets.sync.client import connect @@ -146,11 +146,14 @@ def run(self) -> None: websocket.send(recovery_response.to_json()) elif msg_type == p.MessageType.CANCEL: cancel_message = p.CancelRequest.from_json(message) + cancel_ids = cancel_message.workflow_ids or [ + cancel_message.workflow_id + ] success = True try: - self.dbos.cancel_workflow(cancel_message.workflow_id) + self.dbos.cancel_workflows(cancel_ids) except Exception as e: - error_message = f"Exception encountered when cancelling workflow {cancel_message.workflow_id}: {traceback.format_exc()}" + error_message = f"Exception encountered when cancelling workflow(s) {cancel_ids}: {traceback.format_exc()}" self.dbos.logger.error(error_message) success = False cancel_response = p.CancelResponse( @@ -162,15 +165,18 @@ def run(self) -> None: websocket.send(cancel_response.to_json()) elif msg_type == p.MessageType.DELETE: delete_message = p.DeleteRequest.from_json(message) + delete_ids = delete_message.workflow_ids or [ + delete_message.workflow_id + ] success = True try: delete_workflow( self.dbos, - delete_message.workflow_id, + delete_ids, delete_children=delete_message.delete_children, ) except Exception as e: - error_message = f"Exception encountered when deleting workflow {delete_message.workflow_id}: {traceback.format_exc()}" + error_message = f"Exception encountered when deleting workflow(s) {delete_ids}: {traceback.format_exc()}" self.dbos.logger.error(error_message) success = False delete_response = p.DeleteResponse( @@ -182,11 +188,14 @@ def run(self) -> None: websocket.send(delete_response.to_json()) elif msg_type == p.MessageType.RESUME: resume_message = p.ResumeRequest.from_json(message) + resume_ids = resume_message.workflow_ids or [ + resume_message.workflow_id + ] success = True try: - self.dbos.resume_workflow(resume_message.workflow_id) + self.dbos.resume_workflows(resume_ids) except Exception as e: - error_message = f"Exception encountered when resuming workflow {resume_message.workflow_id}: {traceback.format_exc()}" + error_message = f"Exception encountered when resuming workflow(s) {resume_ids}: {traceback.format_exc()}" self.dbos.logger.error(error_message) success = False resume_response = p.ResumeResponse( @@ -357,6 +366,81 @@ def run(self) -> None: error_message=error_message, ) websocket.send(get_workflow_response.to_json()) + elif msg_type == p.MessageType.GET_WORKFLOW_EVENTS: + events_message = p.GetWorkflowEventsRequest.from_json( + message + ) + event_outputs: Optional[List[p.EventOutput]] = None + error_message = None + try: + raw_events = self.dbos._sys_db.get_all_events( + events_message.workflow_id + ) + event_outputs = [ + p.EventOutput.from_event_data(k, v) + for k, v in raw_events.items() + ] + except Exception as e: + error_message = f"Exception encountered when getting events for workflow {events_message.workflow_id}: {traceback.format_exc()}" + self.dbos.logger.error(error_message) + websocket.send( + p.GetWorkflowEventsResponse( + type=p.MessageType.GET_WORKFLOW_EVENTS, + request_id=base_message.request_id, + events=event_outputs, + error_message=error_message, + ).to_json() + ) + elif msg_type == p.MessageType.GET_WORKFLOW_NOTIFICATIONS: + notif_message = p.GetWorkflowNotificationsRequest.from_json( + message + ) + notif_outputs: Optional[List[p.NotificationOutput]] = None + error_message = None + try: + raw_notifs = self.dbos._sys_db.get_all_notifications( + notif_message.workflow_id + ) + notif_outputs = [ + p.NotificationOutput.from_notification_info(n) + for n in raw_notifs + ] + except Exception as e: + error_message = f"Exception encountered when getting notifications for workflow {notif_message.workflow_id}: {traceback.format_exc()}" + self.dbos.logger.error(error_message) + websocket.send( + p.GetWorkflowNotificationsResponse( + type=p.MessageType.GET_WORKFLOW_NOTIFICATIONS, + request_id=base_message.request_id, + notifications=notif_outputs, + error_message=error_message, + ).to_json() + ) + elif msg_type == p.MessageType.GET_WORKFLOW_STREAMS: + streams_message = p.GetWorkflowStreamsRequest.from_json( + message + ) + stream_outputs: Optional[List[p.StreamEntryOutput]] = None + error_message = None + try: + raw_streams = self.dbos._sys_db.get_all_stream_entries( + streams_message.workflow_id + ) + stream_outputs = [ + p.StreamEntryOutput.from_stream_data(k, v) + for k, v in raw_streams.items() + ] + except Exception as e: + error_message = f"Exception encountered when getting streams for workflow {streams_message.workflow_id}: {traceback.format_exc()}" + self.dbos.logger.error(error_message) + websocket.send( + p.GetWorkflowStreamsResponse( + type=p.MessageType.GET_WORKFLOW_STREAMS, + request_id=base_message.request_id, + streams=stream_outputs, + error_message=error_message, + ).to_json() + ) elif msg_type == p.MessageType.EXIST_PENDING_WORKFLOWS: exist_pending_workflows_message = ( p.ExistPendingWorkflowsRequest.from_json(message) diff --git a/dbos/_conductor/protocol.py b/dbos/_conductor/protocol.py index af529c15..e09ed533 100644 --- a/dbos/_conductor/protocol.py +++ b/dbos/_conductor/protocol.py @@ -4,7 +4,13 @@ from typing import Dict, List, Optional, Type, TypedDict, TypeVar, Union from dbos._serialization import Serializer -from dbos._sys_db import StepInfo, VersionInfo, WorkflowSchedule, WorkflowStatus +from dbos._sys_db import ( + NotificationInfo, + StepInfo, + VersionInfo, + WorkflowSchedule, + WorkflowStatus, +) class MessageType(str, Enum): @@ -33,6 +39,9 @@ class MessageType(str, Enum): TRIGGER_SCHEDULE = "trigger_schedule" LIST_APPLICATION_VERSIONS = "list_application_versions" SET_LATEST_APPLICATION_VERSION = "set_latest_application_version" + GET_WORKFLOW_EVENTS = "get_workflow_events" + GET_WORKFLOW_NOTIFICATIONS = "get_workflow_notifications" + GET_WORKFLOW_STREAMS = "get_workflow_streams" T = TypeVar("T", bound="BaseMessage") @@ -96,6 +105,7 @@ class RecoveryResponse(BaseMessage): @dataclass class CancelRequest(BaseMessage): workflow_id: str + workflow_ids: Optional[List[str]] = None @dataclass @@ -108,6 +118,7 @@ class CancelResponse(BaseMessage): class DeleteRequest(BaseMessage): workflow_id: str delete_children: bool + workflow_ids: Optional[List[str]] = None @dataclass @@ -119,6 +130,7 @@ class DeleteResponse(BaseMessage): @dataclass class ResumeRequest(BaseMessage): workflow_id: str + workflow_ids: Optional[List[str]] = None @dataclass @@ -587,3 +599,73 @@ class SetLatestApplicationVersionRequest(BaseMessage): class SetLatestApplicationVersionResponse(BaseMessage): success: bool error_message: Optional[str] = None + + +@dataclass +class NotificationOutput: + topic: Optional[str] + message: str + created_at_epoch_ms: int + consumed: bool + + @classmethod + def from_notification_info(cls, info: NotificationInfo) -> "NotificationOutput": + return cls( + topic=info["topic"], + message=str(info["message"]), + created_at_epoch_ms=info["created_at_epoch_ms"], + consumed=info["consumed"], + ) + + +@dataclass +class StreamEntryOutput: + key: str + values: List[str] + + @classmethod + def from_stream_data(cls, key: str, values: List[object]) -> "StreamEntryOutput": + return cls(key=key, values=[str(v) for v in values]) + + +@dataclass +class EventOutput: + key: str + value: str + + @classmethod + def from_event_data(cls, key: str, value: object) -> "EventOutput": + return cls(key=key, value=str(value)) + + +@dataclass +class GetWorkflowEventsRequest(BaseMessage): + workflow_id: str + + +@dataclass +class GetWorkflowEventsResponse(BaseMessage): + events: Optional[List[EventOutput]] + error_message: Optional[str] = None + + +@dataclass +class GetWorkflowNotificationsRequest(BaseMessage): + workflow_id: str + + +@dataclass +class GetWorkflowNotificationsResponse(BaseMessage): + notifications: Optional[List[NotificationOutput]] + error_message: Optional[str] = None + + +@dataclass +class GetWorkflowStreamsRequest(BaseMessage): + workflow_id: str + + +@dataclass +class GetWorkflowStreamsResponse(BaseMessage): + streams: Optional[List[StreamEntryOutput]] + error_message: Optional[str] = None diff --git a/dbos/_core.py b/dbos/_core.py index 46003467..bad2571c 100644 --- a/dbos/_core.py +++ b/dbos/_core.py @@ -463,7 +463,7 @@ def timeout_func() -> None: was_stopped = evt.wait(time_to_wait_sec) if was_stopped: return - dbos._sys_db.cancel_workflow(wfid) + dbos._sys_db.cancel_workflows([wfid]) except Exception as e: dbos.logger.warning( f"Exception in timeout thread for workflow {wfid}: {e}" diff --git a/dbos/_dbos.py b/dbos/_dbos.py index 08507e4a..711094da 100644 --- a/dbos/_dbos.py +++ b/dbos/_dbos.py @@ -1468,25 +1468,35 @@ def _recover_pending_workflows( @classmethod def cancel_workflow(cls, workflow_id: str) -> None: """Cancel a workflow by ID.""" - check_async("cancel_workflow") + cls.cancel_workflows([workflow_id]) + + @classmethod + async def cancel_workflow_async(cls, workflow_id: str) -> None: + """Cancel a workflow by ID.""" + await cls.cancel_workflows_async([workflow_id]) + + @classmethod + def cancel_workflows(cls, workflow_ids: List[str]) -> None: + """Cancel multiple workflows by ID.""" + check_async("cancel_workflows") def fn() -> None: - dbos_logger.info(f"Cancelling workflow: {workflow_id}") - _get_dbos_instance()._sys_db.cancel_workflow(workflow_id) + dbos_logger.info(f"Cancelling workflow(s): {workflow_ids}") + _get_dbos_instance()._sys_db.cancel_workflows(workflow_ids) return _get_dbos_instance()._sys_db.call_function_as_step( fn, "DBOS.cancelWorkflow", snapshot_step_context(reserve_sleep_id=False) ) @classmethod - async def cancel_workflow_async(cls, workflow_id: str) -> None: - """Cancel a workflow by ID.""" + async def cancel_workflows_async(cls, workflow_ids: List[str]) -> None: + """Cancel multiple workflows by ID.""" step_ctx = snapshot_step_context(reserve_sleep_id=False) await cls._configure_asyncio_thread_pool() def fn() -> None: - dbos_logger.info(f"Cancelling workflow: {workflow_id}") - _get_dbos_instance()._sys_db.cancel_workflow(workflow_id) + dbos_logger.info(f"Cancelling workflow(s): {workflow_ids}") + _get_dbos_instance()._sys_db.cancel_workflows(workflow_ids) return await asyncio.to_thread( _get_dbos_instance()._sys_db.call_function_as_step, @@ -1499,16 +1509,30 @@ def fn() -> None: def delete_workflow( cls, workflow_id: str, *, delete_children: bool = False ) -> None: - """Delete a workflow and all its associated data by ID. + """Delete a workflow and all its associated data by ID.""" + cls.delete_workflows([workflow_id], delete_children=delete_children) + + @classmethod + async def delete_workflow_async( + cls, workflow_id: str, *, delete_children: bool = False + ) -> None: + """Delete a workflow and all its associated data by ID.""" + await cls.delete_workflows_async([workflow_id], delete_children=delete_children) + + @classmethod + def delete_workflows( + cls, workflow_ids: List[str], *, delete_children: bool = False + ) -> None: + """Delete multiple workflows and all their associated data by ID. If delete_children is True, also deletes all child workflows recursively. """ - check_async("delete_workflow") + check_async("delete_workflows") def fn() -> None: - dbos_logger.info(f"Deleting workflow: {workflow_id}") + dbos_logger.info(f"Deleting workflow(s): {workflow_ids}") delete_workflow( - _get_dbos_instance(), workflow_id, delete_children=delete_children + _get_dbos_instance(), workflow_ids, delete_children=delete_children ) return _get_dbos_instance()._sys_db.call_function_as_step( @@ -1516,10 +1540,10 @@ def fn() -> None: ) @classmethod - async def delete_workflow_async( - cls, workflow_id: str, *, delete_children: bool = False + async def delete_workflows_async( + cls, workflow_ids: List[str], *, delete_children: bool = False ) -> None: - """Delete a workflow and all its associated data by ID. + """Delete multiple workflows and all their associated data by ID. If delete_children is True, also deletes all child workflows recursively. """ @@ -1527,9 +1551,9 @@ async def delete_workflow_async( await cls._configure_asyncio_thread_pool() def fn() -> None: - dbos_logger.info(f"Deleting workflow: {workflow_id}") + dbos_logger.info(f"Deleting workflow(s): {workflow_ids}") delete_workflow( - _get_dbos_instance(), workflow_id, delete_children=delete_children + _get_dbos_instance(), workflow_ids, delete_children=delete_children ) return await asyncio.to_thread( @@ -1556,23 +1580,22 @@ def resume_workflow(cls, workflow_id: str) -> WorkflowHandle[Any]: def fn() -> None: dbos_logger.info(f"Resuming workflow: {workflow_id}") - _get_dbos_instance()._sys_db.resume_workflow(workflow_id) + _get_dbos_instance()._sys_db.resume_workflows([workflow_id]) _get_dbos_instance()._sys_db.call_function_as_step( fn, "DBOS.resumeWorkflow", snapshot_step_context(reserve_sleep_id=False) ) - return cls.retrieve_workflow(workflow_id) + return WorkflowHandlePolling(workflow_id, _get_dbos_instance()) @classmethod async def resume_workflow_async(cls, workflow_id: str) -> WorkflowHandleAsync[Any]: """Resume a workflow by ID.""" step_ctx_res = snapshot_step_context(reserve_sleep_id=False) - step_ctx_ret = snapshot_step_context(reserve_sleep_id=False) await cls._configure_asyncio_thread_pool() def fnres() -> None: dbos_logger.info(f"Resuming workflow: {workflow_id}") - _get_dbos_instance()._sys_db.resume_workflow(workflow_id) + _get_dbos_instance()._sys_db.resume_workflows([workflow_id]) await asyncio.to_thread( _get_dbos_instance()._sys_db.call_function_as_step, @@ -1580,19 +1603,47 @@ def fnres() -> None: "DBOS.resumeWorkflow", step_ctx_res, ) + return WorkflowHandleAsyncPolling(workflow_id, _get_dbos_instance()) - def fnret() -> Optional[WorkflowStatus]: - return get_workflow(_get_dbos_instance()._sys_db, workflow_id) + @classmethod + def resume_workflows(cls, workflow_ids: List[str]) -> List[WorkflowHandle[Any]]: + """Resume multiple workflows by ID.""" + check_async("resume_workflows") - stat = await asyncio.to_thread( + def fn() -> None: + dbos_logger.info(f"Resuming workflows: {workflow_ids}") + _get_dbos_instance()._sys_db.resume_workflows(workflow_ids) + + _get_dbos_instance()._sys_db.call_function_as_step( + fn, "DBOS.resumeWorkflow", snapshot_step_context(reserve_sleep_id=False) + ) + return [ + WorkflowHandlePolling(wfid, _get_dbos_instance()) for wfid in workflow_ids + ] + + @classmethod + async def resume_workflows_async( + cls, workflow_ids: List[str] + ) -> List[WorkflowHandleAsync[Any]]: + """Resume multiple workflows by ID.""" + step_ctx_res = snapshot_step_context(reserve_sleep_id=False) + await cls._configure_asyncio_thread_pool() + + def fnres() -> None: + dbos_logger.info(f"Resuming workflows: {workflow_ids}") + _get_dbos_instance()._sys_db.resume_workflows(workflow_ids) + + await asyncio.to_thread( _get_dbos_instance()._sys_db.call_function_as_step, - fnret, - "DBOS.getStatus", - step_ctx_ret, + fnres, + "DBOS.resumeWorkflow", + step_ctx_res, ) - if stat is None: - raise DBOSNonExistentWorkflowError("target", workflow_id) - return WorkflowHandleAsyncPolling(workflow_id, _get_dbos_instance()) + + return [ + WorkflowHandleAsyncPolling(wfid, _get_dbos_instance()) + for wfid in workflow_ids + ] @classmethod def fork_workflow( diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py index 653d3a9b..f81de7ce 100644 --- a/dbos/_sys_db.py +++ b/dbos/_sys_db.py @@ -285,6 +285,13 @@ class StepInfo(TypedDict): completed_at_epoch_ms: Optional[int] +class NotificationInfo(TypedDict): + topic: Optional[str] + message: Any + created_at_epoch_ms: int + consumed: bool + + _dbos_null_topic = "__null__topic__" _dbos_stream_closed_sentinel = "__DBOS_STREAM_CLOSED__" @@ -687,16 +694,16 @@ def update_workflow_outcome( .where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id) ) - def cancel_workflow( + def cancel_workflows( self, - workflow_id: str, + workflow_ids: list[str], ) -> None: with self.engine.begin() as c: - # Set the workflow's status to CANCELLED and remove it from any queue it is on, + # Set the workflows' status to CANCELLED and remove them from any queue, # but only if the workflow is not already complete. c.execute( sa.update(SystemSchema.workflow_status) - .where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id) + .where(SystemSchema.workflow_status.c.workflow_uuid.in_(workflow_ids)) .where( SystemSchema.workflow_status.c.status.notin_( [ @@ -714,13 +721,13 @@ def cancel_workflow( ) ) - def resume_workflow(self, workflow_id: str) -> None: + def resume_workflows(self, workflow_ids: list[str]) -> None: with self.engine.begin() as c: - # Set the workflow's status to ENQUEUED and clear its recovery attempts and deadline, + # Set the workflows' status to ENQUEUED and clear recovery attempts and deadline, # but only if the workflow is not already complete. c.execute( sa.update(SystemSchema.workflow_status) - .where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id) + .where(SystemSchema.workflow_status.c.workflow_uuid.in_(workflow_ids)) .where( SystemSchema.workflow_status.c.status.notin_( [ @@ -2182,6 +2189,67 @@ def get_all_events(self, workflow_id: str) -> Dict[str, Any]: return events + def get_all_notifications(self, workflow_id: str) -> List[NotificationInfo]: + """Get all notifications sent to a workflow.""" + with self.engine.begin() as c: + rows = c.execute( + sa.select( + SystemSchema.notifications.c.topic, + SystemSchema.notifications.c.message, + SystemSchema.notifications.c.serialization, + SystemSchema.notifications.c.created_at_epoch_ms, + SystemSchema.notifications.c.consumed, + ) + .where(SystemSchema.notifications.c.destination_uuid == workflow_id) + .order_by(SystemSchema.notifications.c.created_at_epoch_ms) + ).fetchall() + results: List[NotificationInfo] = [] + for row in rows: + topic = row[0] + if topic == _dbos_null_topic: + topic = None + results.append( + { + "topic": topic, + "message": deserialize_value(row[1], row[2], self.serializer), + "created_at_epoch_ms": row[3], + "consumed": row[4], + } + ) + return results + + def get_all_stream_entries(self, workflow_id: str) -> Dict[str, List[Any]]: + """Get all stream entries for a workflow. + + Returns a dict mapping stream keys to lists of deserialized values (ordered by offset). + """ + with self.engine.begin() as c: + rows = c.execute( + sa.select( + SystemSchema.streams.c.key, + SystemSchema.streams.c.value, + SystemSchema.streams.c.offset, + SystemSchema.streams.c.serialization, + ) + .where(SystemSchema.streams.c.workflow_uuid == workflow_id) + .order_by( + SystemSchema.streams.c.key, + SystemSchema.streams.c.offset, + ) + ).fetchall() + streams: Dict[str, List[Any]] = {} + for row in rows: + key = row[0] + value_str = row[1] + serialization = row[3] + value = deserialize_value(value_str, serialization, self.serializer) + if value == _dbos_stream_closed_sentinel: + continue + if key not in streams: + streams[key] = [] + streams[key].append(value) + return streams + @db_retry() def get_event_setup( self, diff --git a/dbos/_workflow_commands.py b/dbos/_workflow_commands.py index e004b275..4f8fd258 100644 --- a/dbos/_workflow_commands.py +++ b/dbos/_workflow_commands.py @@ -40,17 +40,20 @@ def fork_workflow( return forked_workflow_id -def delete_workflow(dbos: "DBOS", workflow_id: str, *, delete_children: bool) -> None: - """Delete a workflow and all its associated data. +def delete_workflow( + dbos: "DBOS", workflow_ids: list[str], *, delete_children: bool +) -> None: + """Delete workflows and all their associated data. If delete_children is True, also deletes all child workflows recursively. """ - workflow_ids = [workflow_id] + all_ids = list(workflow_ids) if delete_children: - workflow_ids.extend(dbos._sys_db.get_workflow_children(workflow_id)) - dbos._sys_db.delete_workflows(workflow_ids) + for wfid in workflow_ids: + all_ids.extend(dbos._sys_db.get_workflow_children(wfid)) + dbos._sys_db.delete_workflows(all_ids) if dbos._app_db: - dbos._app_db.delete_transaction_outputs(workflow_ids) + dbos._app_db.delete_transaction_outputs(all_ids) def garbage_collect( diff --git a/tests/test_async_workflow_management.py b/tests/test_async_workflow_management.py index 2651cdb7..3e7f9819 100644 --- a/tests/test_async_workflow_management.py +++ b/tests/test_async_workflow_management.py @@ -7,6 +7,7 @@ from dbos import DBOS, Queue, SetWorkflowID from dbos._error import DBOSAwaitedWorkflowCancelledError from dbos._sys_db import StepInfo, WorkflowStatus +from dbos._utils import INTERNAL_QUEUE_NAME from tests.conftest import queue_entries_are_cleaned_up @@ -245,3 +246,65 @@ async def simple_workflow(x: int) -> int: step_names = {step["function_name"] for step in steps} assert any("step_one" in name for name in step_names) assert any("step_two" in name for name in step_names) + + +@pytest.mark.asyncio +async def test_bulk_async_workflow_management(dbos: DBOS) -> None: + """Test that all bulk async workflow management methods work and are checkpointed.""" + + @DBOS.workflow() + async def target_workflow(x: int) -> int: + return x + + # Create target workflows + cancel_ids: list[str] = [] + resume_ids: list[str] = [] + delete_ids: list[str] = [] + for i in range(2): + for ids in [cancel_ids, resume_ids, delete_ids]: + wfid = str(uuid.uuid4()) + ids.append(wfid) + with SetWorkflowID(wfid): + h = await DBOS.start_workflow_async(target_workflow, i) + await h.get_result() + + # Pick one workflow to fork + fork_target_id = resume_ids[0] + + @DBOS.workflow() + async def management_workflow() -> None: + await DBOS.cancel_workflows_async(cancel_ids) + handles = await DBOS.resume_workflows_async(resume_ids) + assert len(handles) == 2 + for i, h in enumerate(handles): + assert (await h.get_result()) == i + await DBOS.delete_workflows_async(delete_ids) + workflows = await DBOS.list_workflows_async(workflow_ids=cancel_ids) + assert len(workflows) == 2 + for wf in workflows: + assert wf.workflow_id in cancel_ids + steps = await DBOS.list_workflow_steps_async(fork_target_id) + assert isinstance(steps, list) + forked = await DBOS.fork_workflow_async(fork_target_id, 1) + assert (await forked.get_result()) == 0 + + mgmt_wfid = str(uuid.uuid4()) + with SetWorkflowID(mgmt_wfid): + handle = await DBOS.start_workflow_async(management_workflow) + await handle.get_result() + + # Verify the management operations are checkpointed as steps + steps = await DBOS.list_workflow_steps_async(mgmt_wfid) + step_names = [step["function_name"] for step in steps] + assert "DBOS.cancelWorkflow" in step_names + assert "DBOS.resumeWorkflow" in step_names + assert "DBOS.deleteWorkflow" in step_names + assert "DBOS.listWorkflows" in step_names + assert "DBOS.listWorkflowSteps" in step_names + assert "DBOS.forkWorkflow" in step_names + + # Verify delete actually took effect + for did in delete_ids: + assert await DBOS.get_workflow_status_async(did) is None + + assert queue_entries_are_cleaned_up(dbos) diff --git a/tests/test_workflow_management.py b/tests/test_workflow_management.py index 4010f019..eea6ef1d 100644 --- a/tests/test_workflow_management.py +++ b/tests/test_workflow_management.py @@ -5,7 +5,7 @@ import pytest import sqlalchemy as sa -from dbos import DBOS, Queue, SetWorkflowID, WorkflowHandle +from dbos import DBOS, DBOSClient, Queue, SetWorkflowID, WorkflowHandle from dbos._error import DBOSAwaitedWorkflowCancelledError from dbos._schemas.application_database import ApplicationSchema from dbos._utils import INTERNAL_QUEUE_NAME, GlobalParams @@ -149,6 +149,144 @@ def parent_workflow(x: int) -> int: DBOS.delete_workflow(parent_wfid2, delete_children=False) +def test_bulk_cancel(dbos: DBOS) -> None: + steps_completed = 0 + workflow_events: dict[str, threading.Event] = {} + main_events: dict[str, threading.Event] = {} + + @DBOS.step() + def step_one() -> None: + nonlocal steps_completed + steps_completed += 1 + + @DBOS.step() + def step_two() -> None: + nonlocal steps_completed + steps_completed += 1 + + @DBOS.workflow() + def blocking_workflow() -> str: + wfid = DBOS.workflow_id + assert wfid is not None + step_one() + main_events[wfid].set() + workflow_events[wfid].wait() + step_two() + return wfid + + # Start three workflows, wait for each to reach its blocking point + wfids: list[str] = [] + handles = [] + for _ in range(3): + wfid = str(uuid.uuid4()) + wfids.append(wfid) + workflow_events[wfid] = threading.Event() + main_events[wfid] = threading.Event() + with SetWorkflowID(wfid): + handles.append(DBOS.start_workflow(blocking_workflow)) + main_events[wfid].wait() + + assert steps_completed == 3 + + # Bulk cancel all three workflows at once + DBOS.cancel_workflows(wfids) + + # Release all workflows so they can observe cancellation + for evt in workflow_events.values(): + evt.set() + + for handle in handles: + with pytest.raises(DBOSAwaitedWorkflowCancelledError): + handle.get_result() + + # step_two should not have run for any workflow + assert steps_completed == 3 + + assert queue_entries_are_cleaned_up(dbos) + + +def test_bulk_resume(dbos: DBOS) -> None: + steps_completed = 0 + workflow_events: dict[str, threading.Event] = {} + main_events: dict[str, threading.Event] = {} + + @DBOS.step() + def step_one() -> None: + nonlocal steps_completed + steps_completed += 1 + + @DBOS.step() + def step_two() -> None: + nonlocal steps_completed + steps_completed += 1 + + @DBOS.workflow() + def blocking_workflow(x: int) -> int: + wfid = DBOS.workflow_id + assert wfid is not None + step_one() + main_events[wfid].set() + workflow_events[wfid].wait() + step_two() + return x + + # Start three workflows and cancel them + wfids: list[str] = [] + handles = [] + for i in range(3): + wfid = str(uuid.uuid4()) + wfids.append(wfid) + workflow_events[wfid] = threading.Event() + main_events[wfid] = threading.Event() + with SetWorkflowID(wfid): + handles.append(DBOS.start_workflow(blocking_workflow, i)) + main_events[wfid].wait() + + assert steps_completed == 3 + + DBOS.cancel_workflows(wfids) + for evt in workflow_events.values(): + evt.set() + for handle in handles: + with pytest.raises(DBOSAwaitedWorkflowCancelledError): + handle.get_result() + assert steps_completed == 3 + + # Bulk resume all three workflows + resumed_handles = DBOS.resume_workflows(wfids) + assert len(resumed_handles) == 3 + for i, handle in enumerate(resumed_handles): + assert handle.get_result() == i + assert steps_completed == 6 + + assert queue_entries_are_cleaned_up(dbos) + + +def test_bulk_delete(dbos: DBOS) -> None: + @DBOS.workflow() + def simple_workflow(x: int) -> int: + return x + + # Run three workflows + wfids: list[str] = [] + for i in range(3): + wfid = str(uuid.uuid4()) + wfids.append(wfid) + with SetWorkflowID(wfid): + assert simple_workflow(i) == i + + # Verify all exist + for wfid in wfids: + assert DBOS.get_workflow_status(wfid) is not None + + # Bulk delete all three + DBOS.delete_workflows(wfids) + + # Verify all are gone + for wfid in wfids: + assert DBOS.get_workflow_status(wfid) is None + + def test_cancel_resume_txn(dbos: DBOS) -> None: txn_completed = 0 workflow_event = threading.Event() @@ -882,3 +1020,129 @@ def workflow() -> str: for handle in [fork_one, fork_two, fork_three, fork_four, fork_five]: assert handle.get_result() assert list(DBOS.read_stream(handle.workflow_id, key)) == [0, 1, 2] + + +def test_get_all_events(dbos: DBOS) -> None: + @DBOS.workflow() + def event_workflow() -> str: + DBOS.set_event("key1", "value1") + DBOS.set_event("key2", 42) + DBOS.set_event("key1", "updated") + return DBOS.workflow_id # type: ignore + + handle = DBOS.start_workflow(event_workflow) + wfid = handle.get_result() + + events = dbos._sys_db.get_all_events(wfid) + assert events == {"key1": "updated", "key2": 42} + + # Empty workflow has no events + empty_events = dbos._sys_db.get_all_events("nonexistent") + assert empty_events == {} + + +def test_get_all_notifications(dbos: DBOS) -> None: + recv_event = threading.Event() + + @DBOS.workflow() + def receiver_workflow() -> str: + DBOS.recv(topic="topic_a") + DBOS.recv(topic="topic_b") + recv_event.set() + return DBOS.workflow_id # type: ignore + + wfid = str(uuid.uuid4()) + with SetWorkflowID(wfid): + handle = DBOS.start_workflow(receiver_workflow) + + # Send messages to the receiver workflow (two consumed, one unconsumed) + DBOS.send(wfid, "hello", topic="topic_a") + DBOS.send(wfid, {"data": 123}, topic="topic_b") + recv_event.wait() + handle.get_result() + + # Send an extra message that the workflow never consumes + DBOS.send(wfid, "unconsumed", topic="topic_c") + + notifications = dbos._sys_db.get_all_notifications(wfid) + assert len(notifications) == 3 + assert notifications[0]["topic"] == "topic_a" + assert notifications[0]["message"] == "hello" + assert notifications[0]["consumed"] is True + assert notifications[1]["topic"] == "topic_b" + assert notifications[1]["message"] == {"data": 123} + assert notifications[1]["consumed"] is True + assert notifications[2]["topic"] == "topic_c" + assert notifications[2]["message"] == "unconsumed" + assert notifications[2]["consumed"] is False + + # Nonexistent workflow has no notifications + assert dbos._sys_db.get_all_notifications("nonexistent") == [] + + +def test_get_all_notifications_null_topic(dbos: DBOS) -> None: + recv_event = threading.Event() + + @DBOS.workflow() + def receiver_workflow() -> str: + DBOS.recv() + recv_event.set() + return DBOS.workflow_id # type: ignore + + wfid = str(uuid.uuid4()) + with SetWorkflowID(wfid): + handle = DBOS.start_workflow(receiver_workflow) + + DBOS.send(wfid, "no_topic_msg") + recv_event.wait() + handle.get_result() + + notifications = dbos._sys_db.get_all_notifications(wfid) + assert len(notifications) == 1 + assert notifications[0]["topic"] is None + assert notifications[0]["message"] == "no_topic_msg" + + +def test_get_all_stream_entries(dbos: DBOS) -> None: + @DBOS.workflow() + def stream_workflow() -> str: + DBOS.write_stream("stream_a", 10) + DBOS.write_stream("stream_a", 20) + DBOS.write_stream("stream_b", "hello") + DBOS.close_stream("stream_a") + DBOS.close_stream("stream_b") + return DBOS.workflow_id # type: ignore + + handle = DBOS.start_workflow(stream_workflow) + wfid = handle.get_result() + + streams = dbos._sys_db.get_all_stream_entries(wfid) + assert streams == {"stream_a": [10, 20], "stream_b": ["hello"]} + + # Nonexistent workflow has no streams + assert dbos._sys_db.get_all_stream_entries("nonexistent") == {} + + +def test_client_delete_workflow(client: DBOSClient, dbos: DBOS) -> None: + @DBOS.workflow() + def simple_workflow(x: int) -> int: + return x + + # Test single delete + wfid = str(uuid.uuid4()) + with SetWorkflowID(wfid): + assert simple_workflow(1) == 1 + assert len(client.list_workflows(workflow_ids=[wfid])) == 1 + client.delete_workflow(wfid) + assert len(client.list_workflows(workflow_ids=[wfid])) == 0 + + # Test bulk delete + wfids: list[str] = [] + for i in range(3): + wfid = str(uuid.uuid4()) + wfids.append(wfid) + with SetWorkflowID(wfid): + assert simple_workflow(i) == i + assert len(client.list_workflows(workflow_ids=wfids)) == 3 + client.delete_workflows(wfids) + assert len(client.list_workflows(workflow_ids=wfids)) == 0