Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type DBOSContext interface {
Cancel()

// Workflow operations
RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error)
RunAsStep(_ DBOSContext, fn StepFunc) (any, error)
RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error)
Send(_ DBOSContext, input WorkflowSendInputInternal) error
Recv(_ DBOSContext, input WorkflowRecvInput) (any, error)
Expand Down
6 changes: 4 additions & 2 deletions dbos/queues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ This suite tests
*/

func queueWorkflow(ctx DBOSContext, input string) (string, error) {
step1, err := RunAsStep(ctx, queueStep, input)
step1, err := RunAsStep(ctx, func(context context.Context) (string, error) {
return queueStep(context, input)
})
if err != nil {
return "", fmt.Errorf("failed to run step: %v", err)
}
return step1, nil
}

func queueStep(ctx context.Context, input string) (string, error) {
func queueStep(_ context.Context, input string) (string, error) {
return input, nil
}

Expand Down
14 changes: 9 additions & 5 deletions dbos/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) {
}

func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) {
stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123)
stepResult, err := RunAsStep(ctx, func(context context.Context) (int, error) {
return encodingStepBuiltinTypes(context, 123)
})
return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err)
}

Expand All @@ -49,13 +51,15 @@ type SimpleStruct struct {
}

func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) {
return RunAsStep(ctx, encodingStepStruct, StepInputStruct{
A: input.A,
B: fmt.Sprintf("%d", input.B),
return RunAsStep(ctx, func(context context.Context) (StepOutputStruct, error) {
return encodingStepStruct(context, StepInputStruct{
A: input.A,
B: fmt.Sprintf("%d", input.B),
})
})
}

func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputStruct, error) {
func encodingStepStruct(_ context.Context, input StepInputStruct) (StepOutputStruct, error) {
return StepOutputStruct{
A: input,
B: "processed by encodingStepStruct",
Expand Down
48 changes: 19 additions & 29 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,8 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o
/******* STEP FUNCTIONS *******/
/******************************/

type StepFunc func(ctx context.Context, input any) (any, error)
type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error)
type StepFunc func(ctx context.Context) (any, error)
type GenericStepFunc[R any] func(ctx context.Context) (R, error)

const StepParamsKey DBOSContextKey = "stepParams"

Expand Down Expand Up @@ -729,7 +729,7 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams {

var typeErasedStepNameToStepName = make(map[string]string)

func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) {
func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
if ctx == nil {
return *new(R), newStepExecutionError("", "", "ctx cannot be nil")
}
Expand All @@ -738,46 +738,36 @@ func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P)
return *new(R), newStepExecutionError("", "", "step function cannot be nil")
}

// Type-erase the function based on its actual type
typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) {
typedInput, ok := input.(P)
if !ok {
return nil, newStepExecutionError("", "", fmt.Sprintf("unexpected input type: expected %T, got %T", *new(P), input))
}
return fn(ctx, typedInput)
})
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()

typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// Type-erase the function
typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) })
typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = stepName

// Call the executor method
result, err := ctx.RunAsStep(ctx, typeErasedFn, input)
if err != nil {
// In case the errors comes from the DBOS step logic, the result will be nil and we must handle it
if result == nil {
return *new(R), err
}
return result.(R), err
// Call the executor method and pass through the result/error
result, err := ctx.RunAsStep(ctx, typeErasedFn)
// Step function could return a nil result
if result == nil {
return *new(R), err
}

// Type-check and cast the result
// Otherwise type-check and cast the result
typedResult, ok := result.(R)
if !ok {
return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result)
}

return typedResult, nil
return typedResult, err
}

func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) {
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
// Get workflow state from context
wfState, ok := c.Value(workflowStateKey).(*workflowState)
if !ok || wfState == nil {
// TODO: try to print step name
return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?")
}

// This should not happen when called from the package-level RunAsStep
if fn == nil {
// TODO: try to print step name
return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil")
}

Expand All @@ -790,7 +780,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err

// If within a step, just run the function directly
if wfState.isWithinStep {
return fn(c, input)
return fn(c)
}

// Setup step state
Expand Down Expand Up @@ -819,7 +809,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
// Spawn a child DBOSContext with the step state
stepCtx := WithValue(c, workflowStateKey, &stepState)

stepOutput, stepError := fn(stepCtx, input)
stepOutput, stepError := fn(stepCtx)

// Retry if MaxRetries > 0 and the first execution failed
var joinedErrors error
Expand All @@ -845,7 +835,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
}

// Execute the retry
stepOutput, stepError = fn(stepCtx, input)
stepOutput, stepError = fn(stepCtx)

// If successful, break
if stepError == nil {
Expand Down
80 changes: 55 additions & 25 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,23 @@ func simpleWorkflowError(dbosCtx DBOSContext, input string) (int, error) {
}

func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, simpleStep, input)
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
}

func simpleStep(ctx context.Context, input string) (string, error) {
func simpleStep(_ context.Context) (string, error) {
return "from step", nil
}

func simpleStepError(ctx context.Context, input string) (string, error) {
func simpleStepError(_ context.Context) (string, error) {
return "", fmt.Errorf("step failure")
}

func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, simpleStepError, input)
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return simpleStepError(ctx)
})
}

// idempotencyWorkflow increments a global counter and returns the input
Expand Down Expand Up @@ -292,38 +296,44 @@ func TestWorkflowsRegistration(t *testing.T) {
}
}

func stepWithinAStep(ctx context.Context, input string) (string, error) {
return simpleStep(ctx, input)
func stepWithinAStep(ctx context.Context) (string, error) {
return simpleStep(ctx)
}

func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, stepWithinAStep, input)
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return stepWithinAStep(ctx)
})
}

// Global counter for retry testing
var stepRetryAttemptCount int

func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) {
func stepRetryAlwaysFailsStep(ctx context.Context) (string, error) {
stepRetryAttemptCount++
return "", fmt.Errorf("always fails - attempt %d", stepRetryAttemptCount)
}

var stepIdempotencyCounter int

func stepIdempotencyTest(ctx context.Context, input int) (string, error) {
func stepIdempotencyTest(ctx context.Context) (string, error) {
stepIdempotencyCounter++
return "", nil
}

func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) {
RunAsStep(dbosCtx, stepIdempotencyTest, 1)
RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return stepIdempotencyTest(ctx)
})
stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{
MaxRetries: 5,
BaseInterval: 1 * time.Millisecond,
MaxInterval: 10 * time.Millisecond,
})

return RunAsStep(stepCtx, stepRetryAlwaysFailsStep, input)
return RunAsStep(stepCtx, func(ctx context.Context) (string, error) {
return stepRetryAlwaysFailsStep(ctx)
})
}

func TestSteps(t *testing.T) {
Expand All @@ -335,7 +345,9 @@ func TestSteps(t *testing.T) {

t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) {
// Attempt to run a step outside of a workflow context
_, err := RunAsStep(dbosCtx, simpleStep, "test")
_, err := RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
if err == nil {
t.Fatal("expected error when running step outside of workflow context, but got none")
}
Expand Down Expand Up @@ -411,7 +423,7 @@ func TestSteps(t *testing.T) {
}

// Verify the error contains the step name and max retries
expectedErrorMessage := "dbos.stepRetryAlwaysFailsStep has exceeded its maximum of 5 retries"
expectedErrorMessage := "has exceeded its maximum of 5 retries"
if !strings.Contains(dbosErr.Message, expectedErrorMessage) {
t.Fatalf("expected error message to contain '%s', got '%s'", expectedErrorMessage, dbosErr.Message)
}
Expand Down Expand Up @@ -470,7 +482,9 @@ func TestChildWorkflow(t *testing.T) {
return "", fmt.Errorf("expected childWf workflow ID to be %s, got %s", expectedCurrentID, workflowID)
}
// Steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow
return RunAsStep(dbosCtx, simpleStep, "")
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
}
RegisterWorkflow(dbosCtx, childWf)

Expand Down Expand Up @@ -644,7 +658,9 @@ func TestChildWorkflow(t *testing.T) {
customChildID := uuid.NewString()

simpleChildWf := func(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, simpleStep, input)
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
}
RegisterWorkflow(dbosCtx, simpleChildWf)

Expand Down Expand Up @@ -713,23 +729,29 @@ func TestChildWorkflow(t *testing.T) {
// Idempotency workflows moved to test functions

func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) {
RunAsStep(dbosCtx, incrementCounter, int64(1))
RunAsStep(dbosCtx, func(ctx context.Context) (int64, error) {
return incrementCounter(ctx, int64(1))
})
return input, nil
}

var blockingStepStopEvent *Event

func blockingStep(ctx context.Context, input string) (string, error) {
func blockingStep(_ context.Context) (string, error) {
blockingStepStopEvent.Wait()
return "", nil
}

var idempotencyWorkflowWithStepEvent *Event

func idempotencyWorkflowWithStep(dbosCtx DBOSContext, input string) (int64, error) {
RunAsStep(dbosCtx, incrementCounter, int64(1))
RunAsStep(dbosCtx, func(ctx context.Context) (int64, error) {
return incrementCounter(ctx, int64(1))
})
idempotencyWorkflowWithStepEvent.Set()
RunAsStep(dbosCtx, blockingStep, input)
RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
return blockingStep(ctx)
})
return idempotencyCounter, nil
}

Expand Down Expand Up @@ -1131,9 +1153,9 @@ func TestScheduledWorkflows(t *testing.T) {
}

// Stop the workflowScheduler and check if it stops executing
currentCounter := counter
dbosCtx.(*dbosContext).getWorkflowScheduler().Stop()
time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions
currentCounter := counter // If more scheduled executions happen, this can also trigger a data race. If the scheduler is correct, there should be no race.
if counter >= currentCounter+2 {
t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2)
}
Expand Down Expand Up @@ -1253,7 +1275,9 @@ func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, er
}

func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) {
return RunAsStep(ctx, stepThatCallsSend, input)
return RunAsStep(ctx, func(context context.Context) (string, error) {
return stepThatCallsSend(context, input)
})
}

type sendRecvType struct {
Expand Down Expand Up @@ -2193,7 +2217,7 @@ func TestWorkflowTimeout(t *testing.T) {
}
})

waitForCancelStep := func(ctx context.Context, _ string) (string, error) {
waitForCancelStep := func(ctx context.Context) (string, error) {
// This step will trigger cancellation of the entire workflow context
<-ctx.Done()
if !errors.Is(ctx.Err(), context.Canceled) && !errors.Is(ctx.Err(), context.DeadlineExceeded) {
Expand All @@ -2203,7 +2227,9 @@ func TestWorkflowTimeout(t *testing.T) {
}

waitForCancelWorkflowWithStep := func(ctx DBOSContext, _ string) (string, error) {
return RunAsStep(ctx, waitForCancelStep, "trigger-cancellation")
return RunAsStep(ctx, func(context context.Context) (string, error) {
return waitForCancelStep(context)
})
}
RegisterWorkflow(dbosCtx, waitForCancelWorkflowWithStep)

Expand Down Expand Up @@ -2240,7 +2266,9 @@ func TestWorkflowTimeout(t *testing.T) {
// The timeout will trigger a step error, the workflow can do whatever it wants with that error
stepCtx, stepCancelFunc := WithTimeout(ctx, 1*time.Millisecond)
defer stepCancelFunc() // Ensure we clean up the context
_, err := RunAsStep(stepCtx, waitForCancelStep, "short-step-timeout")
_, err := RunAsStep(stepCtx, func(context context.Context) (string, error) {
return waitForCancelStep(context)
})
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected step to timeout, got: %v", err)
}
Expand Down Expand Up @@ -2287,7 +2315,9 @@ func TestWorkflowTimeout(t *testing.T) {
// This workflow will run a step that is not cancelable.
// What this means is the workflow *will* be cancelled, but the step will run normally
stepCtx := WithoutCancel(ctx)
res, err := RunAsStep(stepCtx, detachedStep, timeout*2)
res, err := RunAsStep(stepCtx, func(context context.Context) (string, error) {
return detachedStep(context, timeout*2)
})
if err != nil {
t.Fatalf("failed to run detached step: %v", err)
}
Expand Down
Loading