Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions dbos/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,47 @@ async def get_event_async(
)

def cancel_workflow(self, workflow_id: str) -> None:
self._sys_db.cancel_workflow(workflow_id)
self._sys_db.cancel_workflows([workflow_id])

async def cancel_workflow_async(self, workflow_id: str) -> None:
await asyncio.to_thread(self.cancel_workflow, workflow_id)

def cancel_workflows(self, workflow_ids: List[str]) -> None:
self._sys_db.cancel_workflows(workflow_ids)

async def cancel_workflows_async(self, workflow_ids: List[str]) -> None:
await asyncio.to_thread(self._sys_db.cancel_workflows, workflow_ids)

def delete_workflow(
self, workflow_id: str, *, delete_children: bool = False
) -> None:
self.delete_workflows([workflow_id], delete_children=delete_children)

async def delete_workflow_async(
self, workflow_id: str, *, delete_children: bool = False
) -> None:
await asyncio.to_thread(
self.delete_workflows, [workflow_id], delete_children=delete_children
)

def delete_workflows(
self, workflow_ids: List[str], *, delete_children: bool = False
) -> None:
all_ids = list(workflow_ids)
if delete_children:
for wfid in workflow_ids:
all_ids.extend(self._sys_db.get_workflow_children(wfid))
self._sys_db.delete_workflows(all_ids)

async def delete_workflows_async(
self, workflow_ids: List[str], *, delete_children: bool = False
) -> None:
await asyncio.to_thread(
self.delete_workflows, workflow_ids, delete_children=delete_children
)

def resume_workflow(self, workflow_id: str) -> "WorkflowHandle[Any]":
self._sys_db.resume_workflow(workflow_id)
self._sys_db.resume_workflows([workflow_id])
return WorkflowHandleClientPolling[Any](workflow_id, self._sys_db)

async def resume_workflow_async(
Expand All @@ -379,6 +413,22 @@ async def resume_workflow_async(
await asyncio.to_thread(self.resume_workflow, workflow_id)
return WorkflowHandleClientAsyncPolling[Any](workflow_id, self._sys_db)

def resume_workflows(self, workflow_ids: List[str]) -> "List[WorkflowHandle[Any]]":
self._sys_db.resume_workflows(workflow_ids)
return [
WorkflowHandleClientPolling[Any](wfid, self._sys_db)
for wfid in workflow_ids
]

async def resume_workflows_async(
self, workflow_ids: List[str]
) -> "List[WorkflowHandleAsync[Any]]":
await asyncio.to_thread(self._sys_db.resume_workflows, workflow_ids)
return [
WorkflowHandleClientAsyncPolling[Any](wfid, self._sys_db)
for wfid in workflow_ids
]

def list_workflows(
self,
*,
Expand Down
98 changes: 91 additions & 7 deletions dbos/_conductor/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import traceback
from datetime import datetime
from importlib.metadata import version
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from websockets import ConnectionClosed, ConnectionClosedOK, InvalidStatus
from websockets.sync.client import connect
Expand Down Expand Up @@ -146,11 +146,14 @@ def run(self) -> None:
websocket.send(recovery_response.to_json())
elif msg_type == p.MessageType.CANCEL:
cancel_message = p.CancelRequest.from_json(message)
cancel_ids = cancel_message.workflow_ids or [
cancel_message.workflow_id
]
success = True
try:
self.dbos.cancel_workflow(cancel_message.workflow_id)
self.dbos.cancel_workflows(cancel_ids)
except Exception as e:
error_message = f"Exception encountered when cancelling workflow {cancel_message.workflow_id}: {traceback.format_exc()}"
error_message = f"Exception encountered when cancelling workflow(s) {cancel_ids}: {traceback.format_exc()}"
self.dbos.logger.error(error_message)
success = False
cancel_response = p.CancelResponse(
Expand All @@ -162,15 +165,18 @@ def run(self) -> None:
websocket.send(cancel_response.to_json())
elif msg_type == p.MessageType.DELETE:
delete_message = p.DeleteRequest.from_json(message)
delete_ids = delete_message.workflow_ids or [
delete_message.workflow_id
]
success = True
try:
delete_workflow(
self.dbos,
delete_message.workflow_id,
delete_ids,
delete_children=delete_message.delete_children,
)
except Exception as e:
error_message = f"Exception encountered when deleting workflow {delete_message.workflow_id}: {traceback.format_exc()}"
error_message = f"Exception encountered when deleting workflow(s) {delete_ids}: {traceback.format_exc()}"
self.dbos.logger.error(error_message)
success = False
delete_response = p.DeleteResponse(
Expand All @@ -182,11 +188,14 @@ def run(self) -> None:
websocket.send(delete_response.to_json())
elif msg_type == p.MessageType.RESUME:
resume_message = p.ResumeRequest.from_json(message)
resume_ids = resume_message.workflow_ids or [
resume_message.workflow_id
]
success = True
try:
self.dbos.resume_workflow(resume_message.workflow_id)
self.dbos.resume_workflows(resume_ids)
except Exception as e:
error_message = f"Exception encountered when resuming workflow {resume_message.workflow_id}: {traceback.format_exc()}"
error_message = f"Exception encountered when resuming workflow(s) {resume_ids}: {traceback.format_exc()}"
self.dbos.logger.error(error_message)
success = False
resume_response = p.ResumeResponse(
Expand Down Expand Up @@ -357,6 +366,81 @@ def run(self) -> None:
error_message=error_message,
)
websocket.send(get_workflow_response.to_json())
elif msg_type == p.MessageType.GET_WORKFLOW_EVENTS:
events_message = p.GetWorkflowEventsRequest.from_json(
message
)
event_outputs: Optional[List[p.EventOutput]] = None
error_message = None
try:
raw_events = self.dbos._sys_db.get_all_events(
events_message.workflow_id
)
event_outputs = [
p.EventOutput.from_event_data(k, v)
for k, v in raw_events.items()
]
except Exception as e:
error_message = f"Exception encountered when getting events for workflow {events_message.workflow_id}: {traceback.format_exc()}"
self.dbos.logger.error(error_message)
websocket.send(
p.GetWorkflowEventsResponse(
type=p.MessageType.GET_WORKFLOW_EVENTS,
request_id=base_message.request_id,
events=event_outputs,
error_message=error_message,
).to_json()
)
elif msg_type == p.MessageType.GET_WORKFLOW_NOTIFICATIONS:
notif_message = p.GetWorkflowNotificationsRequest.from_json(
message
)
notif_outputs: Optional[List[p.NotificationOutput]] = None
error_message = None
try:
raw_notifs = self.dbos._sys_db.get_all_notifications(
notif_message.workflow_id
)
notif_outputs = [
p.NotificationOutput.from_notification_info(n)
for n in raw_notifs
]
except Exception as e:
error_message = f"Exception encountered when getting notifications for workflow {notif_message.workflow_id}: {traceback.format_exc()}"
self.dbos.logger.error(error_message)
websocket.send(
p.GetWorkflowNotificationsResponse(
type=p.MessageType.GET_WORKFLOW_NOTIFICATIONS,
request_id=base_message.request_id,
notifications=notif_outputs,
error_message=error_message,
).to_json()
)
elif msg_type == p.MessageType.GET_WORKFLOW_STREAMS:
streams_message = p.GetWorkflowStreamsRequest.from_json(
message
)
stream_outputs: Optional[List[p.StreamEntryOutput]] = None
error_message = None
try:
raw_streams = self.dbos._sys_db.get_all_stream_entries(
streams_message.workflow_id
)
stream_outputs = [
p.StreamEntryOutput.from_stream_data(k, v)
for k, v in raw_streams.items()
]
except Exception as e:
error_message = f"Exception encountered when getting streams for workflow {streams_message.workflow_id}: {traceback.format_exc()}"
self.dbos.logger.error(error_message)
websocket.send(
p.GetWorkflowStreamsResponse(
type=p.MessageType.GET_WORKFLOW_STREAMS,
request_id=base_message.request_id,
streams=stream_outputs,
error_message=error_message,
).to_json()
)
elif msg_type == p.MessageType.EXIST_PENDING_WORKFLOWS:
exist_pending_workflows_message = (
p.ExistPendingWorkflowsRequest.from_json(message)
Expand Down
84 changes: 83 additions & 1 deletion dbos/_conductor/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from typing import Dict, List, Optional, Type, TypedDict, TypeVar, Union

from dbos._serialization import Serializer
from dbos._sys_db import StepInfo, VersionInfo, WorkflowSchedule, WorkflowStatus
from dbos._sys_db import (
NotificationInfo,
StepInfo,
VersionInfo,
WorkflowSchedule,
WorkflowStatus,
)


class MessageType(str, Enum):
Expand Down Expand Up @@ -33,6 +39,9 @@ class MessageType(str, Enum):
TRIGGER_SCHEDULE = "trigger_schedule"
LIST_APPLICATION_VERSIONS = "list_application_versions"
SET_LATEST_APPLICATION_VERSION = "set_latest_application_version"
GET_WORKFLOW_EVENTS = "get_workflow_events"
GET_WORKFLOW_NOTIFICATIONS = "get_workflow_notifications"
GET_WORKFLOW_STREAMS = "get_workflow_streams"


T = TypeVar("T", bound="BaseMessage")
Expand Down Expand Up @@ -96,6 +105,7 @@ class RecoveryResponse(BaseMessage):
@dataclass
class CancelRequest(BaseMessage):
workflow_id: str
workflow_ids: Optional[List[str]] = None


@dataclass
Expand All @@ -108,6 +118,7 @@ class CancelResponse(BaseMessage):
class DeleteRequest(BaseMessage):
workflow_id: str
delete_children: bool
workflow_ids: Optional[List[str]] = None


@dataclass
Expand All @@ -119,6 +130,7 @@ class DeleteResponse(BaseMessage):
@dataclass
class ResumeRequest(BaseMessage):
workflow_id: str
workflow_ids: Optional[List[str]] = None


@dataclass
Expand Down Expand Up @@ -587,3 +599,73 @@ class SetLatestApplicationVersionRequest(BaseMessage):
class SetLatestApplicationVersionResponse(BaseMessage):
success: bool
error_message: Optional[str] = None


@dataclass
class NotificationOutput:
topic: Optional[str]
message: str
created_at_epoch_ms: int
consumed: bool

@classmethod
def from_notification_info(cls, info: NotificationInfo) -> "NotificationOutput":
return cls(
topic=info["topic"],
message=str(info["message"]),
created_at_epoch_ms=info["created_at_epoch_ms"],
consumed=info["consumed"],
)


@dataclass
class StreamEntryOutput:
key: str
values: List[str]

@classmethod
def from_stream_data(cls, key: str, values: List[object]) -> "StreamEntryOutput":
return cls(key=key, values=[str(v) for v in values])


@dataclass
class EventOutput:
key: str
value: str

@classmethod
def from_event_data(cls, key: str, value: object) -> "EventOutput":
return cls(key=key, value=str(value))


@dataclass
class GetWorkflowEventsRequest(BaseMessage):
workflow_id: str


@dataclass
class GetWorkflowEventsResponse(BaseMessage):
events: Optional[List[EventOutput]]
error_message: Optional[str] = None


@dataclass
class GetWorkflowNotificationsRequest(BaseMessage):
workflow_id: str


@dataclass
class GetWorkflowNotificationsResponse(BaseMessage):
notifications: Optional[List[NotificationOutput]]
error_message: Optional[str] = None


@dataclass
class GetWorkflowStreamsRequest(BaseMessage):
workflow_id: str


@dataclass
class GetWorkflowStreamsResponse(BaseMessage):
streams: Optional[List[StreamEntryOutput]]
error_message: Optional[str] = None
2 changes: 1 addition & 1 deletion dbos/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def timeout_func() -> None:
was_stopped = evt.wait(time_to_wait_sec)
if was_stopped:
return
dbos._sys_db.cancel_workflow(wfid)
dbos._sys_db.cancel_workflows([wfid])
except Exception as e:
dbos.logger.warning(
f"Exception in timeout thread for workflow {wfid}: {e}"
Expand Down
Loading