Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 42 additions & 38 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1103,42 +1103,46 @@ func (s *systemDatabase) notificationListenerLoop(ctx context.Context) {
const _DBOS_NULL_TOPIC = "__null__topic__"

// Send is a special type of step that sends a message to another workflow.
// Three differences with a normal steps: durability and the function run in the same transaction, and we forbid nested step execution
// Can be called both within a workflow (as a step) or outside a workflow (directly).
// When called within a workflow: durability and the function run in the same transaction, and we forbid nested step execution
func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) error {
functionName := "DBOS.send"

// Get workflow state from context
// Get workflow state from context (optional for Send as we can send from outside a workflow)
wfState, ok := ctx.Value(workflowStateKey).(*workflowState)
if !ok || wfState == nil {
return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?")
}
var stepID int
var isInWorkflow bool

if wfState.isWithinStep {
return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step")
if ok && wfState != nil {
isInWorkflow = true
if wfState.isWithinStep {
return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step")
}
stepID = wfState.NextStepID()
}

stepID := wfState.NextStepID()

tx, err := s.pool.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)

// Check if operation was already executed and do nothing if so
checkInput := checkOperationExecutionDBInput{
workflowID: wfState.workflowID,
stepID: stepID,
stepName: functionName,
tx: tx,
}
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
if err != nil {
return err
}
if recordedResult != nil {
// when hitting this case, recordedResult will be &{<nil> <nil>}
return nil
// Check if operation was already executed and do nothing if so (only if in workflow)
if isInWorkflow {
checkInput := checkOperationExecutionDBInput{
workflowID: wfState.workflowID,
stepID: stepID,
stepName: functionName,
tx: tx,
}
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
if err != nil {
return err
}
if recordedResult != nil {
// when hitting this case, recordedResult will be &{<nil> <nil>}
return nil
}
}

// Set default topic if not provided
Expand All @@ -1153,9 +1157,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
return fmt.Errorf("failed to serialize message: %w", err)
}

insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message)
VALUES ($1, $2, $3)`

insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)`
_, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString)
if err != nil {
// Check for foreign key violation (destination workflow doesn't exist)
Expand All @@ -1165,19 +1167,21 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
return fmt.Errorf("failed to insert notification: %w", err)
}

// Record the operation result
recordInput := recordOperationResultDBInput{
workflowID: wfState.workflowID,
stepID: stepID,
stepName: functionName,
output: nil,
err: nil,
tx: tx,
}
// Record the operation result if this is called within a workflow
if isInWorkflow {
recordInput := recordOperationResultDBInput{
workflowID: wfState.workflowID,
stepID: stepID,
stepName: functionName,
output: nil,
err: nil,
tx: tx,
}

err = s.RecordOperationResult(ctx, recordInput)
if err != nil {
return fmt.Errorf("failed to record operation result: %w", err)
err = s.RecordOperationResult(ctx, recordInput)
if err != nil {
return fmt.Errorf("failed to record operation result: %w", err)
}
}

// Commit transaction
Expand Down
101 changes: 83 additions & 18 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ var (
recvIdempotencyWf = WithWorkflow(receiveIdempotencyWorkflow)
receiveIdempotencyStartEvent = NewEvent()
receiveIdempotencyStopEvent = NewEvent()
sendWithinStepWf = WithWorkflow(workflowThatCallsSendInStep)
numConcurrentRecvWfs = 5
concurrentRecvReadyEvents = make([]*Event, numConcurrentRecvWfs)
concurrentRecvStartEvent = NewEvent()
Expand Down Expand Up @@ -1063,6 +1064,22 @@ func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, erro
return msg, nil
}

func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) {
err := Send(ctx, WorkflowSendInput{
DestinationID: input.DestinationID,
Topic: input.Topic,
Message: "message-from-step",
})
if err != nil {
return "", err
}
return "send-completed", nil
}

func workflowThatCallsSendInStep(ctx context.Context, input sendWorkflowInput) (string, error) {
return RunAsStep(ctx, stepThatCallsSend, input)
}

type sendRecvType struct {
Value string
}
Expand Down Expand Up @@ -1185,13 +1202,13 @@ func TestSendRecv(t *testing.T) {
}
})

t.Run("SendRecvMustRunInsideWorkflows", func(t *testing.T) {
t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) {
ctx := context.Background()

// Attempt to run Send outside of a workflow context
err := Send(ctx, WorkflowSendInput{DestinationID: "test-id", Topic: "test-topic", Message: "test-message"})
// Attempt to run Recv outside of a workflow context
_, err := Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second})
if err == nil {
t.Fatal("expected error when running Send outside of workflow context, but got none")
t.Fatal("expected error when running Recv outside of workflow context, but got none")
}

// Check the error type
Expand All @@ -1209,26 +1226,35 @@ func TestSendRecv(t *testing.T) {
if !strings.Contains(err.Error(), expectedMessagePart) {
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
}
})

// Attempt to run Recv outside of a workflow context
_, err = Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second})
if err == nil {
t.Fatal("expected error when running Recv outside of workflow context, but got none")
t.Run("SendOutsideWorkflow", func(t *testing.T) {
// Start a receive workflow to have a valid destination
receiveHandle, err := receiveWf(context.Background(), "outside-workflow-topic")
if err != nil {
t.Fatalf("failed to start receive workflow: %v", err)
}

// Check the error type
dbosErr, ok = err.(*DBOSError)
if !ok {
t.Fatalf("expected error to be of type *DBOSError, got %T", err)
// Send messages from outside a workflow context (should work now)
ctx := context.Background()
for i := range 3 {
err = Send(ctx, WorkflowSendInput{
DestinationID: receiveHandle.GetWorkflowID(),
Topic: "outside-workflow-topic",
Message: fmt.Sprintf("message%d", i+1),
})
if err != nil {
t.Fatalf("failed to send message%d from outside workflow: %v", i+1, err)
}
}

if dbosErr.Code != StepExecutionError {
t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code)
// Verify the receive workflow gets all messages
result, err := receiveHandle.GetResult(context.Background())
if err != nil {
t.Fatalf("failed to get result from receive workflow: %v", err)
}

// Test the specific message from the error
if !strings.Contains(err.Error(), expectedMessagePart) {
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
if result != "message1-message2-message3" {
t.Fatalf("expected result to be 'message1-message2-message3', got '%s'", result)
}
})
t.Run("SendRecvIdempotency", func(t *testing.T) {
Expand Down Expand Up @@ -1292,6 +1318,45 @@ func TestSendRecv(t *testing.T) {
}
})

t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) {
// Start a receive workflow to have a valid destination
receiveHandle, err := receiveWf(context.Background(), "send-within-step-topic")
if err != nil {
t.Fatalf("failed to start receive workflow: %v", err)
}

// Execute the workflow that tries to call Send within a step
handle, err := sendWithinStepWf(context.Background(), sendWorkflowInput{
DestinationID: receiveHandle.GetWorkflowID(),
Topic: "send-within-step-topic",
})
if err != nil {
t.Fatalf("failed to start workflow: %v", err)
}

// Expect the workflow to fail with the specific error
_, err = handle.GetResult(context.Background())
if err == nil {
t.Fatal("expected error when calling Send within a step, but got none")
}

// Check the error type
dbosErr, ok := err.(*DBOSError)
if !ok {
t.Fatalf("expected error to be of type *DBOSError, got %T", err)
}

if dbosErr.Code != StepExecutionError {
t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code)
}

// Test the specific message from the error
expectedMessagePart := "cannot call Send within a step"
if !strings.Contains(err.Error(), expectedMessagePart) {
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
}
})

t.Run("ConcurrentRecv", func(t *testing.T) {
// Test concurrent receivers - only 1 should timeout, others should get errors
receiveTopic := "concurrent-recv-topic"
Expand Down
Loading