Skip to content

Commit ba0ddbb

Browse files
committed
test send within a step is forbidden (if send is done within a workflow)
1 parent 7c7dc42 commit ba0ddbb

File tree

2 files changed

+125
-56
lines changed

2 files changed

+125
-56
lines changed

dbos/system_database.go

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,42 +1103,46 @@ func (s *systemDatabase) notificationListenerLoop(ctx context.Context) {
11031103
const _DBOS_NULL_TOPIC = "__null__topic__"
11041104

11051105
// Send is a special type of step that sends a message to another workflow.
1106-
// Three differences with a normal steps: durability and the function run in the same transaction, and we forbid nested step execution
1106+
// Can be called both within a workflow (as a step) or outside a workflow (directly).
1107+
// When called within a workflow: durability and the function run in the same transaction, and we forbid nested step execution
11071108
func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) error {
11081109
functionName := "DBOS.send"
11091110

1110-
// Get workflow state from context
1111+
// Get workflow state from context (optional for Send as we can send from outside a workflow)
11111112
wfState, ok := ctx.Value(workflowStateKey).(*workflowState)
1112-
if !ok || wfState == nil {
1113-
return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?")
1114-
}
1113+
var stepID int
1114+
var isInWorkflow bool
11151115

1116-
if wfState.isWithinStep {
1117-
return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step")
1116+
if ok && wfState != nil {
1117+
isInWorkflow = true
1118+
if wfState.isWithinStep {
1119+
return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step")
1120+
}
1121+
stepID = wfState.NextStepID()
11181122
}
11191123

1120-
stepID := wfState.NextStepID()
1121-
11221124
tx, err := s.pool.Begin(ctx)
11231125
if err != nil {
11241126
return fmt.Errorf("failed to begin transaction: %w", err)
11251127
}
11261128
defer tx.Rollback(ctx)
11271129

1128-
// Check if operation was already executed and do nothing if so
1129-
checkInput := checkOperationExecutionDBInput{
1130-
workflowID: wfState.workflowID,
1131-
stepID: stepID,
1132-
stepName: functionName,
1133-
tx: tx,
1134-
}
1135-
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
1136-
if err != nil {
1137-
return err
1138-
}
1139-
if recordedResult != nil {
1140-
// when hitting this case, recordedResult will be &{<nil> <nil>}
1141-
return nil
1130+
// Check if operation was already executed and do nothing if so (only if in workflow)
1131+
if isInWorkflow {
1132+
checkInput := checkOperationExecutionDBInput{
1133+
workflowID: wfState.workflowID,
1134+
stepID: stepID,
1135+
stepName: functionName,
1136+
tx: tx,
1137+
}
1138+
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
1139+
if err != nil {
1140+
return err
1141+
}
1142+
if recordedResult != nil {
1143+
// when hitting this case, recordedResult will be &{<nil> <nil>}
1144+
return nil
1145+
}
11421146
}
11431147

11441148
// Set default topic if not provided
@@ -1153,9 +1157,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
11531157
return fmt.Errorf("failed to serialize message: %w", err)
11541158
}
11551159

1156-
insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message)
1157-
VALUES ($1, $2, $3)`
1158-
1160+
insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)`
11591161
_, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString)
11601162
if err != nil {
11611163
// Check for foreign key violation (destination workflow doesn't exist)
@@ -1165,19 +1167,21 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
11651167
return fmt.Errorf("failed to insert notification: %w", err)
11661168
}
11671169

1168-
// Record the operation result
1169-
recordInput := recordOperationResultDBInput{
1170-
workflowID: wfState.workflowID,
1171-
stepID: stepID,
1172-
stepName: functionName,
1173-
output: nil,
1174-
err: nil,
1175-
tx: tx,
1176-
}
1170+
// Record the operation result if this is called within a workflow
1171+
if isInWorkflow {
1172+
recordInput := recordOperationResultDBInput{
1173+
workflowID: wfState.workflowID,
1174+
stepID: stepID,
1175+
stepName: functionName,
1176+
output: nil,
1177+
err: nil,
1178+
tx: tx,
1179+
}
11771180

1178-
err = s.RecordOperationResult(ctx, recordInput)
1179-
if err != nil {
1180-
return fmt.Errorf("failed to record operation result: %w", err)
1181+
err = s.RecordOperationResult(ctx, recordInput)
1182+
if err != nil {
1183+
return fmt.Errorf("failed to record operation result: %w", err)
1184+
}
11811185
}
11821186

11831187
// Commit transaction

dbos/workflows_test.go

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ var (
973973
recvIdempotencyWf = WithWorkflow(receiveIdempotencyWorkflow)
974974
receiveIdempotencyStartEvent = NewEvent()
975975
receiveIdempotencyStopEvent = NewEvent()
976+
sendWithinStepWf = WithWorkflow(workflowThatCallsSendInStep)
976977
numConcurrentRecvWfs = 5
977978
concurrentRecvReadyEvents = make([]*Event, numConcurrentRecvWfs)
978979
concurrentRecvStartEvent = NewEvent()
@@ -1063,6 +1064,22 @@ func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, erro
10631064
return msg, nil
10641065
}
10651066

1067+
func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) {
1068+
err := Send(ctx, WorkflowSendInput{
1069+
DestinationID: input.DestinationID,
1070+
Topic: input.Topic,
1071+
Message: "message-from-step",
1072+
})
1073+
if err != nil {
1074+
return "", err
1075+
}
1076+
return "send-completed", nil
1077+
}
1078+
1079+
func workflowThatCallsSendInStep(ctx context.Context, input sendWorkflowInput) (string, error) {
1080+
return RunAsStep(ctx, stepThatCallsSend, input)
1081+
}
1082+
10661083
type sendRecvType struct {
10671084
Value string
10681085
}
@@ -1185,13 +1202,13 @@ func TestSendRecv(t *testing.T) {
11851202
}
11861203
})
11871204

1188-
t.Run("SendRecvMustRunInsideWorkflows", func(t *testing.T) {
1205+
t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) {
11891206
ctx := context.Background()
11901207

1191-
// Attempt to run Send outside of a workflow context
1192-
err := Send(ctx, WorkflowSendInput{DestinationID: "test-id", Topic: "test-topic", Message: "test-message"})
1208+
// Attempt to run Recv outside of a workflow context
1209+
_, err := Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second})
11931210
if err == nil {
1194-
t.Fatal("expected error when running Send outside of workflow context, but got none")
1211+
t.Fatal("expected error when running Recv outside of workflow context, but got none")
11951212
}
11961213

11971214
// Check the error type
@@ -1209,26 +1226,35 @@ func TestSendRecv(t *testing.T) {
12091226
if !strings.Contains(err.Error(), expectedMessagePart) {
12101227
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
12111228
}
1229+
})
12121230

1213-
// Attempt to run Recv outside of a workflow context
1214-
_, err = Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second})
1215-
if err == nil {
1216-
t.Fatal("expected error when running Recv outside of workflow context, but got none")
1231+
t.Run("SendOutsideWorkflow", func(t *testing.T) {
1232+
// Start a receive workflow to have a valid destination
1233+
receiveHandle, err := receiveWf(context.Background(), "outside-workflow-topic")
1234+
if err != nil {
1235+
t.Fatalf("failed to start receive workflow: %v", err)
12171236
}
12181237

1219-
// Check the error type
1220-
dbosErr, ok = err.(*DBOSError)
1221-
if !ok {
1222-
t.Fatalf("expected error to be of type *DBOSError, got %T", err)
1238+
// Send messages from outside a workflow context (should work now)
1239+
ctx := context.Background()
1240+
for i := range 3 {
1241+
err = Send(ctx, WorkflowSendInput{
1242+
DestinationID: receiveHandle.GetWorkflowID(),
1243+
Topic: "outside-workflow-topic",
1244+
Message: fmt.Sprintf("message%d", i+1),
1245+
})
1246+
if err != nil {
1247+
t.Fatalf("failed to send message%d from outside workflow: %v", i+1, err)
1248+
}
12231249
}
12241250

1225-
if dbosErr.Code != StepExecutionError {
1226-
t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code)
1251+
// Verify the receive workflow gets all messages
1252+
result, err := receiveHandle.GetResult(context.Background())
1253+
if err != nil {
1254+
t.Fatalf("failed to get result from receive workflow: %v", err)
12271255
}
1228-
1229-
// Test the specific message from the error
1230-
if !strings.Contains(err.Error(), expectedMessagePart) {
1231-
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
1256+
if result != "message1-message2-message3" {
1257+
t.Fatalf("expected result to be 'message1-message2-message3', got '%s'", result)
12321258
}
12331259
})
12341260
t.Run("SendRecvIdempotency", func(t *testing.T) {
@@ -1292,6 +1318,45 @@ func TestSendRecv(t *testing.T) {
12921318
}
12931319
})
12941320

1321+
t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) {
1322+
// Start a receive workflow to have a valid destination
1323+
receiveHandle, err := receiveWf(context.Background(), "send-within-step-topic")
1324+
if err != nil {
1325+
t.Fatalf("failed to start receive workflow: %v", err)
1326+
}
1327+
1328+
// Execute the workflow that tries to call Send within a step
1329+
handle, err := sendWithinStepWf(context.Background(), sendWorkflowInput{
1330+
DestinationID: receiveHandle.GetWorkflowID(),
1331+
Topic: "send-within-step-topic",
1332+
})
1333+
if err != nil {
1334+
t.Fatalf("failed to start workflow: %v", err)
1335+
}
1336+
1337+
// Expect the workflow to fail with the specific error
1338+
_, err = handle.GetResult(context.Background())
1339+
if err == nil {
1340+
t.Fatal("expected error when calling Send within a step, but got none")
1341+
}
1342+
1343+
// Check the error type
1344+
dbosErr, ok := err.(*DBOSError)
1345+
if !ok {
1346+
t.Fatalf("expected error to be of type *DBOSError, got %T", err)
1347+
}
1348+
1349+
if dbosErr.Code != StepExecutionError {
1350+
t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code)
1351+
}
1352+
1353+
// Test the specific message from the error
1354+
expectedMessagePart := "cannot call Send within a step"
1355+
if !strings.Contains(err.Error(), expectedMessagePart) {
1356+
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
1357+
}
1358+
})
1359+
12951360
t.Run("ConcurrentRecv", func(t *testing.T) {
12961361
// Test concurrent receivers - only 1 should timeout, others should get errors
12971362
receiveTopic := "concurrent-recv-topic"

0 commit comments

Comments
 (0)