Skip to content

Commit 7c7dc42

Browse files
authored
durable sleep (#40)
1 parent 8f3f20d commit 7c7dc42

File tree

3 files changed

+189
-4
lines changed

3 files changed

+189
-4
lines changed

dbos/system_database.go

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type SystemDatabase interface {
4242
Recv(ctx context.Context, input WorkflowRecvInput) (any, error)
4343
SetEvent(ctx context.Context, input WorkflowSetEventInput) error
4444
GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error)
45+
Sleep(ctx context.Context, duration time.Duration) (time.Duration, error)
4546
}
4647

4748
type systemDatabase struct {
@@ -712,10 +713,12 @@ func (s *systemDatabase) RecordOperationResult(ctx context.Context, input record
712713
getLogger().Debug("RecordOperationResult SQL", "sql", commandTag.String())
713714
*/
714715

715-
// TODO return DBOSWorkflowConflictIDError(result["workflow_uuid"]) on 23505 conflict ID error
716716
if err != nil {
717717
getLogger().Error("RecordOperationResult Error occurred", "error", err)
718-
return fmt.Errorf("failed to record operation result: %w", err)
718+
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
719+
return newWorkflowConflictIDError(input.workflowID)
720+
}
721+
return err
719722
}
720723

721724
if commandTag.RowsAffected() == 0 {
@@ -980,6 +983,86 @@ func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string
980983
return steps, nil
981984
}
982985

986+
// Sleep is a special type of step that sleeps for a specified duration
987+
// A wakeup time is computed and recorded in the database
988+
// If we sleep is re-executed, it will only sleep for the remaining duration until the wakeup time
989+
func (s *systemDatabase) Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) {
990+
functionName := "DBOS.sleep"
991+
992+
// Get workflow state from context
993+
wfState, ok := ctx.Value(workflowStateKey).(*workflowState)
994+
if !ok || wfState == nil {
995+
return 0, newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?")
996+
}
997+
998+
if wfState.isWithinStep {
999+
return 0, newStepExecutionError(wfState.workflowID, functionName, "cannot call Sleep within a step")
1000+
}
1001+
1002+
stepID := wfState.NextStepID()
1003+
1004+
// Check if operation was already executed
1005+
checkInput := checkOperationExecutionDBInput{
1006+
workflowID: wfState.workflowID,
1007+
stepID: stepID,
1008+
stepName: functionName,
1009+
}
1010+
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
1011+
if err != nil {
1012+
return 0, fmt.Errorf("failed to check operation execution: %w", err)
1013+
}
1014+
1015+
var endTime time.Time
1016+
1017+
if recordedResult != nil {
1018+
if recordedResult.output == nil { // This should never happen
1019+
return 0, fmt.Errorf("no recorded end time for recorded sleep operation")
1020+
}
1021+
1022+
// The output should be a time.Time representing the end time
1023+
endTimeInterface, ok := recordedResult.output.(time.Time)
1024+
if !ok {
1025+
return 0, fmt.Errorf("recorded output is not a time.Time: %T", recordedResult.output)
1026+
}
1027+
endTime = endTimeInterface
1028+
1029+
if recordedResult.err != nil { // This should never happen
1030+
return 0, recordedResult.err
1031+
}
1032+
} else {
1033+
// First execution: calculate and record the end time
1034+
getLogger().Debug("Durable sleep", "stepID", stepID, "duration", duration)
1035+
1036+
endTime = time.Now().Add(duration)
1037+
1038+
// Record the operation result with the calculated end time
1039+
recordInput := recordOperationResultDBInput{
1040+
workflowID: wfState.workflowID,
1041+
stepID: stepID,
1042+
stepName: functionName,
1043+
output: endTime,
1044+
err: nil,
1045+
}
1046+
1047+
err = s.RecordOperationResult(ctx, recordInput)
1048+
if err != nil {
1049+
// Check if this is a ConflictingWorkflowError (operation already recorded by another process)
1050+
if dbosErr, ok := err.(*DBOSError); ok && dbosErr.Code == ConflictingIDError {
1051+
} else {
1052+
return 0, fmt.Errorf("failed to record sleep operation result: %w", err)
1053+
}
1054+
}
1055+
}
1056+
1057+
// Calculate remaining duration until wake up time
1058+
remainingDuration := max(0, time.Until(endTime))
1059+
1060+
// Actually sleep for the remaining duration
1061+
time.Sleep(remainingDuration)
1062+
1063+
return remainingDuration, nil
1064+
}
1065+
9831066
/****************************************/
9841067
/******* WORKFLOW COMMUNICATIONS ********/
9851068
/****************************************/

dbos/workflow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,10 @@ func GetEvent[R any](ctx context.Context, input WorkflowGetEventInput) (R, error
775775
return typedValue, nil
776776
}
777777

778+
// Sleep pauses the calling workflow for the given duration by delegating to
// the system database's Sleep implementation.
//
// ctx is passed through unchanged; the returned duration and error are
// whatever the system database's Sleep returns.
func Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) {
779+
return dbos.systemDB.Sleep(ctx, duration)
780+
}
781+
778782
/***********************************/
779783
/******* WORKFLOW MANAGEMENT *******/
780784
/***********************************/

dbos/workflows_test.go

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ func TestSteps(t *testing.T) {
352352
}
353353

354354
// Test the specific message from the 3rd argument
355-
expectedMessagePart := "workflow state not found in context"
355+
expectedMessagePart := "workflow state not found in context: are you running this step within a workflow?"
356356
if !strings.Contains(err.Error(), expectedMessagePart) {
357357
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
358358
}
@@ -1205,7 +1205,7 @@ func TestSendRecv(t *testing.T) {
12051205
}
12061206

12071207
// Test the specific message from the error
1208-
expectedMessagePart := "workflow state not found in context"
1208+
expectedMessagePart := "workflow state not found in context: are you running this step within a workflow?"
12091209
if !strings.Contains(err.Error(), expectedMessagePart) {
12101210
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
12111211
}
@@ -1737,3 +1737,101 @@ func TestSetGetEvent(t *testing.T) {
17371737
}
17381738
})
17391739
}
1740+
1741+
// Shared state for the durable sleep tests.
var (
1742+
// sleepRecoveryWf is sleepRecoveryWorkflow wrapped with WithWorkflow.
sleepRecoveryWf = WithWorkflow(sleepRecoveryWorkflow)
1743+
// sleepStartEvent is set by the workflow once its Sleep call has returned.
sleepStartEvent *Event
1744+
// sleepStopEvent is what the workflow waits on after sleeping; the test sets it to let the workflow return.
sleepStopEvent *Event
1745+
)
1746+
1747+
// sleepRecoveryWorkflow calls Sleep for the given duration, then signals
// sleepStartEvent and blocks on sleepStopEvent.
// It returns the duration reported by Sleep, or Sleep's error.
func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Duration, error) {
1748+
result, err := Sleep(ctx, duration)
1749+
if err != nil {
1750+
return 0, err
1751+
}
1752+
// Block after sleep so we can recover a pending workflow
1753+
sleepStartEvent.Set()
1754+
sleepStopEvent.Wait()
1755+
return result, nil
1756+
}
1757+
1758+
// TestSleep exercises Sleep in two scenarios:
//   - SleepDurableRecovery: re-running a workflow whose sleep already completed
//     must not sleep for the full duration again, and exactly one "DBOS.sleep"
//     step must be recorded.
//   - SleepCannotBeCalledOutsideWorkflow: calling Sleep with a plain context
//     must fail with a StepExecutionError.
func TestSleep(t *testing.T) {
1759+
setupDBOS(t)
1760+
1761+
t.Run("SleepDurableRecovery", func(t *testing.T) {
1762+
sleepStartEvent = NewEvent()
1763+
sleepStopEvent = NewEvent()
1764+
1765+
// Start a workflow that sleeps for 2 seconds then blocks
1766+
sleepDuration := 2 * time.Second
1767+
1768+
handle, err := sleepRecoveryWf(context.Background(), sleepDuration)
1769+
if err != nil {
1770+
t.Fatalf("failed to start sleep recovery workflow: %v", err)
1771+
}
1772+
1773+
// Wait until the first execution has finished sleeping, then reset the event for the second run
sleepStartEvent.Wait()
1774+
sleepStartEvent.Clear()
1775+
1776+
// Run the workflow again with the same workflow ID and check it returns from the sleep in less than the sleep duration
1777+
startTime := time.Now()
1778+
_, err = sleepRecoveryWf(context.Background(), sleepDuration, WithWorkflowID(handle.GetWorkflowID()))
1779+
if err != nil {
1780+
t.Fatalf("failed to start second sleep recovery workflow: %v", err)
1781+
}
1782+
1783+
sleepStartEvent.Wait()
1784+
// Time elapsed must be strictly less than the sleep duration: the first run
// already slept past the wakeup time, so the second run should not sleep the full duration again
1785+
elapsed := time.Since(startTime)
1786+
if elapsed >= sleepDuration {
1787+
t.Fatalf("expected elapsed time to be less than %v, got %v", sleepDuration, elapsed)
1788+
}
1789+
1790+
// Verify the sleep step was recorded correctly
1791+
steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID())
1792+
if err != nil {
1793+
t.Fatalf("failed to get workflow steps: %v", err)
1794+
}
1795+
1796+
if len(steps) != 1 {
1797+
t.Fatalf("expected 1 step (the sleep), got %d", len(steps))
1798+
}
1799+
1800+
step := steps[0]
1801+
if step.FunctionName != "DBOS.sleep" {
1802+
t.Fatalf("expected step name to be 'DBOS.sleep', got '%s'", step.FunctionName)
1803+
}
1804+
1805+
if step.Error != nil {
1806+
t.Fatalf("expected step to have no error, got %v", step.Error)
1807+
}
1808+
1809+
// Release the workflow executions blocked on sleepStopEvent
sleepStopEvent.Set()
1810+
})
1811+
1812+
t.Run("SleepCannotBeCalledOutsideWorkflow", func(t *testing.T) {
1813+
ctx := context.Background()
1814+
1815+
// Attempt to call Sleep outside of a workflow context
1816+
_, err := Sleep(ctx, 1*time.Second)
1817+
if err == nil {
1818+
t.Fatal("expected error when calling Sleep outside of workflow context, but got none")
1819+
}
1820+
1821+
// Check the error type
1822+
dbosErr, ok := err.(*DBOSError)
1823+
if !ok {
1824+
t.Fatalf("expected error to be of type *DBOSError, got %T", err)
1825+
}
1826+
1827+
if dbosErr.Code != StepExecutionError {
1828+
t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code)
1829+
}
1830+
1831+
// Test the specific message from the error
1832+
expectedMessagePart := "workflow state not found in context: are you running this step within a workflow?"
1833+
if !strings.Contains(err.Error(), expectedMessagePart) {
1834+
t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error())
1835+
}
1836+
})
1837+
}

0 commit comments

Comments
 (0)