Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions dbos/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
DBOSMaxStepRetriesExceeded,
DBOSNonExistentWorkflowError,
DBOSRecoveryError,
DBOSWorkflowCancelledError,
DBOSWorkflowConflictIDError,
DBOSWorkflowFunctionNotFoundError,
)
Expand Down Expand Up @@ -224,6 +225,8 @@ def persist(func: Callable[[], R]) -> R:
)
output = wf_handle.get_result()
return output
except DBOSWorkflowCancelledError as error:
raise
except Exception as error:
status["status"] = "ERROR"
status["error"] = _serialization.serialize_exception(error)
Expand Down Expand Up @@ -539,6 +542,13 @@ def invoke_tx(*args: Any, **kwargs: Any) -> Any:
raise DBOSException(
f"Function {func.__name__} invoked before DBOS initialized"
)

ctx = assert_current_dbos_context()
if dbosreg.is_workflow_cancelled(ctx.workflow_id):
raise DBOSWorkflowCancelledError(
f"Workflow {ctx.workflow_id} is cancelled. Aborting transaction {func.__name__}."
)

dbos = dbosreg.dbos
with dbos._app_db.sessionmaker() as session:
attributes: TracedAttributes = {
Expand All @@ -560,6 +570,12 @@ def invoke_tx(*args: Any, **kwargs: Any) -> Any:
backoff_factor = 1.5
max_retry_wait_seconds = 2.0
while True:

if dbosreg.is_workflow_cancelled(ctx.workflow_id):
raise DBOSWorkflowCancelledError(
f"Workflow {ctx.workflow_id} is cancelled. Aborting transaction {func.__name__}."
)

has_recorded_error = False
txn_error: Optional[Exception] = None
try:
Expand Down Expand Up @@ -710,6 +726,13 @@ def invoke_step(*args: Any, **kwargs: Any) -> Any:
"operationType": OperationType.STEP.value,
}

# Check if the workflow is cancelled
ctx = assert_current_dbos_context()
if dbosreg.is_workflow_cancelled(ctx.workflow_id):
raise DBOSWorkflowCancelledError(
f"Workflow {ctx.workflow_id} is cancelled. Aborting step {func.__name__}."
)

attempts = max_attempts if retries_allowed else 1
max_retry_interval_seconds: float = 3600 # 1 Hour

Expand Down Expand Up @@ -800,6 +823,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
ctx = get_local_dbos_context()
if ctx and ctx.is_step():
# Call the original function directly

return func(*args, **kwargs)
if ctx and ctx.is_within_workflow():
assert ctx.is_workflow(), "Steps must be called from within workflows"
Expand Down
12 changes: 12 additions & 0 deletions dbos/_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(self) -> None:
self.pollers: list[RegisteredJob] = []
self.dbos: Optional[DBOS] = None
self.config: Optional[ConfigFile] = None
self.workflow_cancelled_map: dict[str, bool] = {}

def register_wf_function(self, name: str, wrapped_func: F, functype: str) -> None:
if name in self.function_type_map:
Expand Down Expand Up @@ -197,6 +198,15 @@ def register_instance(self, inst: object) -> None:
else:
self.instance_info_map[fn] = inst

def cancel_workflow(self, workflow_id: str) -> None:
    """Flag ``workflow_id`` as cancelled in the in-memory cancellation map.

    Args:
        workflow_id: ID of the workflow to mark as cancelled.
    """
    cancelled_flags = self.workflow_cancelled_map
    cancelled_flags[workflow_id] = True

def is_workflow_cancelled(self, workflow_id: str) -> bool:
    """Report whether ``workflow_id`` is flagged in the cancellation map.

    Args:
        workflow_id: ID of the workflow to look up.

    Returns:
        The flag stored for the workflow, or ``False`` when the workflow
        has no entry in the map.
    """
    try:
        return self.workflow_cancelled_map[workflow_id]
    except KeyError:
        # No entry means the workflow was never flagged (or was cleared).
        return False

def clear_workflow_cancelled(self, workflow_id: str) -> None:
    """Drop the cancellation flag for ``workflow_id``, if one is recorded.

    Clearing a workflow that has no entry is a no-op.

    Args:
        workflow_id: ID of the workflow whose flag should be removed.
    """
    try:
        del self.workflow_cancelled_map[workflow_id]
    except KeyError:
        # Nothing recorded for this workflow; nothing to clear.
        pass

def compute_app_version(self) -> str:
"""
An application's version is computed from a hash of the source of its workflows.
Expand Down Expand Up @@ -844,11 +854,13 @@ def recover_pending_workflows(
def cancel_workflow(cls, workflow_id: str) -> None:
    """Cancel a workflow by ID.

    The cancellation is first forwarded to the system database, then the
    workflow is flagged as cancelled in the DBOS registry.

    Args:
        workflow_id: ID of the workflow to cancel.
    """
    system_db = _get_dbos_instance()._sys_db
    system_db.cancel_workflow(workflow_id)

    registry = _get_or_create_dbos_registry()
    registry.cancel_workflow(workflow_id)

@classmethod
def resume_workflow(cls, workflow_id: str) -> WorkflowHandle[Any]:
    """Resume a workflow by ID.

    The resume is first forwarded to the system database, then any
    cancellation flag held by the DBOS registry is cleared, and finally
    the workflow is executed again by ID.

    Args:
        workflow_id: ID of the workflow to resume.

    Returns:
        The handle produced by ``execute_workflow_by_id`` for the workflow.
    """
    system_db = _get_dbos_instance()._sys_db
    system_db.resume_workflow(workflow_id)

    registry = _get_or_create_dbos_registry()
    registry.clear_workflow_cancelled(workflow_id)

    handle = execute_workflow_by_id(_get_dbos_instance(), workflow_id, False)
    return handle

@classproperty
Expand Down
11 changes: 11 additions & 0 deletions dbos/_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class DBOSErrorCode(Enum):
MaxStepRetriesExceeded = 7
NotAuthorized = 8
ConflictingWorkflowError = 9
WorkflowCancelled = 10
ConflictingRegistrationError = 25


Expand Down Expand Up @@ -130,6 +131,16 @@ def __init__(self) -> None:
)


class DBOSWorkflowCancelledError(DBOSException):
    """Exception raised when the workflow has already been cancelled.

    Instances carry the ``DBOSErrorCode.WorkflowCancelled`` error code.
    """

    def __init__(self, msg: str) -> None:
        """Build the error from a human-readable message.

        Args:
            msg: Description of what was aborted because of the cancellation.
        """
        error_code = DBOSErrorCode.WorkflowCancelled.value
        super().__init__(msg, dbos_error_code=error_code)


class DBOSConflictingRegistrationError(DBOSException):
"""Exception raised when conflicting decorators are applied to the same function."""

Expand Down
145 changes: 145 additions & 0 deletions tests/test_workflow_cancel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import threading
import time
import uuid
from datetime import datetime, timedelta, timezone

# Public API
from dbos import (
DBOS,
ConfigFile,
Queue,
SetWorkflowID,
WorkflowStatusString,
_workflow_commands,
)


def test_basic(dbos: DBOS, config: ConfigFile) -> None:
    """Run an uncancelled two-step workflow and check both steps executed."""
    steps_completed = 0

    @DBOS.step()
    def step_one() -> None:
        nonlocal steps_completed
        steps_completed = steps_completed + 1
        print("Step one completed!")

    @DBOS.step()
    def step_two() -> None:
        nonlocal steps_completed
        steps_completed = steps_completed + 1
        print("Step two completed!")

    @DBOS.workflow()
    def simple_workflow() -> None:
        # Two steps separated by a one-second DBOS sleep.
        step_one()
        dbos.sleep(1)
        step_two()
        print("Executed Simple workflow")

    # Invoke the workflow directly, then give it a moment to settle.
    simple_workflow()
    time.sleep(1)

    expected_steps = 2
    assert (
        steps_completed == expected_steps
    ), f"Expected steps_completed to be 2, but got {steps_completed}"


def test_two_steps_cancel(dbos: DBOS, config: ConfigFile) -> None:
    """Cancel a workflow between its two steps, then resume it to completion.

    The workflow runs on a background thread so it can be cancelled while it
    is sleeping between ``step_one`` and ``step_two``. A synchronous call
    would only return after both steps had already run, which would make the
    cancellation a no-op and the "one step completed" check unreachable.

    Checks:
        * after cancellation only ``step_one`` has run and the workflow call
          raised ``DBOSWorkflowCancelledError``;
        * after ``resume_workflow`` the remaining step runs exactly once.
    """
    # Imported locally so the test does not depend on the file's top-level
    # import list.
    from dbos._error import DBOSWorkflowCancelledError

    steps_completed = 0
    step_one_ran = threading.Event()
    workflow_errors: list[Exception] = []

    @DBOS.step()
    def step_one() -> None:
        nonlocal steps_completed
        steps_completed += 1
        # Tell the test thread it is now safe to cancel the workflow.
        step_one_ran.set()
        print("Step one completed!")

    @DBOS.step()
    def step_two() -> None:
        nonlocal steps_completed
        steps_completed += 1
        print("Step two completed!")

    @DBOS.workflow()
    def simple_workflow() -> None:
        step_one()
        dbos.sleep(2)
        step_two()
        print("Executed Simple workflow")

    wfuuid = str(uuid.uuid4())

    def run_workflow() -> None:
        # Capture the failure so the main thread can assert on it.
        try:
            with SetWorkflowID(wfuuid):
                simple_workflow()
        except Exception as error:
            workflow_errors.append(error)

    # daemon=True so a failing assertion cannot leave the test run hanging.
    runner = threading.Thread(target=run_workflow, daemon=True)
    runner.start()

    assert step_one_ran.wait(timeout=10), "step_one never ran"
    dbos.cancel_workflow(wfuuid)

    runner.join(timeout=30)
    assert not runner.is_alive(), "Cancelled workflow did not terminate"

    assert (
        steps_completed == 1
    ), f"Expected steps_completed to be 1, but got {steps_completed}"
    assert len(workflow_errors) == 1 and isinstance(
        workflow_errors[0], DBOSWorkflowCancelledError
    ), f"Expected a DBOSWorkflowCancelledError, but got {workflow_errors}"

    # Wait on the handle instead of sleeping for a fixed amount of time.
    dbos.resume_workflow(wfuuid).get_result()

    assert (
        steps_completed == 2
    ), f"Expected steps_completed to be 2, but got {steps_completed}"


def test_two_transactions_cancel(dbos: DBOS, config: ConfigFile) -> None:
    """Cancel a workflow between its two transactions, then resume it twice.

    The workflow runs on a background thread so it can be cancelled while it
    is sleeping between ``transaction_one`` and ``transaction_two``. A
    synchronous call would only return after both transactions had already
    run, which would make the cancellation a no-op.

    Checks:
        * after cancellation only ``transaction_one`` has run and the
          workflow call raised ``DBOSWorkflowCancelledError``;
        * after the first ``resume_workflow`` the second transaction runs;
        * a second ``resume_workflow`` does not run any transaction again.
    """
    # Imported locally so the test does not depend on the file's top-level
    # import list.
    from dbos._error import DBOSWorkflowCancelledError

    tr_completed = 0
    transaction_one_ran = threading.Event()
    workflow_errors: list[Exception] = []

    @DBOS.transaction()
    def transaction_one() -> None:
        nonlocal tr_completed
        tr_completed += 1
        # Tell the test thread it is now safe to cancel the workflow.
        transaction_one_ran.set()
        print("Transaction one completed!")

    @DBOS.transaction()
    def transaction_two() -> None:
        nonlocal tr_completed
        tr_completed += 1
        print("Transaction two completed!")

    @DBOS.workflow()
    def simple_workflow() -> None:
        transaction_one()
        dbos.sleep(2)
        transaction_two()
        print("Executed Simple workflow")

    wfuuid = str(uuid.uuid4())

    def run_workflow() -> None:
        # Capture the failure so the main thread can assert on it.
        try:
            with SetWorkflowID(wfuuid):
                simple_workflow()
        except Exception as error:
            workflow_errors.append(error)

    # daemon=True so a failing assertion cannot leave the test run hanging.
    runner = threading.Thread(target=run_workflow, daemon=True)
    runner.start()

    assert transaction_one_ran.wait(timeout=10), "transaction_one never ran"
    dbos.cancel_workflow(wfuuid)

    runner.join(timeout=30)
    assert not runner.is_alive(), "Cancelled workflow did not terminate"

    assert (
        tr_completed == 1
    ), f"Expected tr_completed to be 1, but got {tr_completed}"
    assert len(workflow_errors) == 1 and isinstance(
        workflow_errors[0], DBOSWorkflowCancelledError
    ), f"Expected a DBOSWorkflowCancelledError, but got {workflow_errors}"

    # Wait on the handle instead of sleeping for a fixed amount of time.
    dbos.resume_workflow(wfuuid).get_result()

    assert (
        tr_completed == 2
    ), f"Expected tr_completed to be 2, but got {tr_completed}"

    # Resuming a second time must not execute either transaction again.
    dbos.resume_workflow(wfuuid).get_result()

    assert (
        tr_completed == 2
    ), f"Expected tr_completed to be 2, but got {tr_completed}"