From c99cf288cf11e3d77051551423ffc0906689da66 Mon Sep 17 00:00:00 2001 From: maxdml Date: Thu, 28 Aug 2025 09:58:35 -0700 Subject: [PATCH] prevent cancel to success transitions --- dbos/system_database.go | 9 ++++--- dbos/workflows_test.go | 54 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/dbos/system_database.go b/dbos/system_database.go index 5a99c377..5cde8188 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -697,11 +697,12 @@ type updateWorkflowOutcomeDBInput struct { tx pgx.Tx } -// Will evolve as we serialize all output and error types +// updateWorkflowOutcome updates the status, output, and error of a workflow +// Note that transitions from CANCELLED to SUCCESS or ERROR are forbidden func (s *sysDB) updateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error { query := `UPDATE dbos.workflow_status SET status = $1, output = $2, error = $3, updated_at = $4, deduplication_id = NULL - WHERE workflow_uuid = $5 AND NOT (status = $6 AND $1 = $7)` + WHERE workflow_uuid = $5 AND NOT (status = $6 AND $1 in ($7, $8))` outputString, err := serialize(input.output) if err != nil { @@ -714,9 +715,9 @@ func (s *sysDB) updateWorkflowOutcome(ctx context.Context, input updateWorkflowO } if input.tx != nil { - _, err = input.tx.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusError) + _, err = input.tx.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) } else { - _, err = s.pool.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusError) + _, err = s.pool.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, 
WorkflowStatusSuccess, WorkflowStatusError) } if err != nil { diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 8fd8d193..7ad31c4a 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -2942,7 +2942,7 @@ func TestWorkflowCancel(t *testing.T) { } RegisterWorkflow(dbosCtx, blockingWorkflow) - t.Run("TestWorkflowCancel", func(t *testing.T) { + t.Run("TestWorkflowCancelWithRecvError", func(t *testing.T) { topic := "cancel-test-topic" // Start the blocking workflow @@ -2971,6 +2971,58 @@ func TestWorkflowCancel(t *testing.T) { require.NoError(t, err, "failed to get workflow status") assert.Equal(t, WorkflowStatusCancelled, status.Status, "expected workflow status to be WorkflowStatusCancelled") }) + + t.Run("TestWorkflowCancelWithSuccess", func(t *testing.T) { + blockingEventNoError := NewEvent() + + // Workflow that waits for an event, then calls Recv(). Does NOT return error when Recv times out + blockingWorkflowNoError := func(ctx DBOSContext, topic string) (string, error) { + // Wait for the event + blockingEventNoError.Wait() + Recv[string](ctx, topic, 5*time.Second) + // Ignore the error + return "", nil + } + RegisterWorkflow(dbosCtx, blockingWorkflowNoError) + + topic := "cancel-no-error-test-topic" + + // Start the blocking workflow + handle, err := RunWorkflow(dbosCtx, blockingWorkflowNoError, topic) + require.NoError(t, err, "failed to start blocking workflow") + + // Cancel the workflow using DBOS.CancelWorkflow + err = CancelWorkflow(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err, "failed to cancel workflow") + + // Signal the event so the workflow can move on to Recv() + blockingEventNoError.Set() + + // Check the return values of the workflow + // Because this is a direct handle it'll not return an error + result, err := handle.GetResult() + require.NoError(t, err, "expected no error from direct handle") + assert.Equal(t, "", result, "expected empty result from cancelled workflow") + + // Now use a polling handle 
to get result -- observe the error + pollingHandle, err := RetrieveWorkflow[string](dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err, "failed to retrieve workflow with polling handle") + + result, err = pollingHandle.GetResult() + require.Error(t, err, "expected error from cancelled workflow even when workflow returns success") + assert.Equal(t, "", result, "expected empty result from cancelled workflow") + + // Check that we still get a DBOSError with AwaitedWorkflowCancelled code + // The gate prevents CANCELLED -> SUCCESS transition + var dbosErr *DBOSError + require.ErrorAs(t, err, &dbosErr, "expected error to be of type *DBOSError, got %T", err) + assert.Equal(t, AwaitedWorkflowCancelled, dbosErr.Code, "expected AwaitedWorkflowCancelled error code, got: %v", dbosErr.Code) + + // Ensure the workflow status remains CANCELLED + status, err := handle.GetStatus() + require.NoError(t, err, "failed to get workflow status") + assert.Equal(t, WorkflowStatusCancelled, status.Status, "expected workflow status to remain WorkflowStatusCancelled due to gate") + }) } var cancelAllBeforeBlockEvent = NewEvent()