Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type DBOSContext interface {
Cancel()

// Workflow operations
RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error)
RunAsStep(_ DBOSContext, fn StepFunc) (any, error)
RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error)
Send(_ DBOSContext, input WorkflowSendInputInternal) error
Recv(_ DBOSContext, input WorkflowRecvInput) (any, error)
Expand Down
6 changes: 4 additions & 2 deletions dbos/queues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ This suite tests
*/

func queueWorkflow(ctx DBOSContext, input string) (string, error) {
step1, err := RunAsStep(ctx, queueStep, input)
step1, err := RunAsStep[string](ctx, func(context context.Context) (string, error) {
return queueStep(context, input)
})
if err != nil {
return "", fmt.Errorf("failed to run step: %v", err)
}
return step1, nil
}

func queueStep(ctx context.Context, input string) (string, error) {
func queueStep(_ context.Context, input string) (string, error) {
return input, nil
}

Expand Down
14 changes: 9 additions & 5 deletions dbos/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) {
}

func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) {
stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123)
stepResult, err := RunAsStep[int](ctx, func(context context.Context) (int, error) {
return encodingStepBuiltinTypes(context, 123)
})
return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err)
}

Expand All @@ -49,13 +51,15 @@ type SimpleStruct struct {
}

func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) {
return RunAsStep(ctx, encodingStepStruct, StepInputStruct{
A: input.A,
B: fmt.Sprintf("%d", input.B),
return RunAsStep[StepOutputStruct](ctx, func(context context.Context) (StepOutputStruct, error) {
return encodingStepStruct(context, StepInputStruct{
A: input.A,
B: fmt.Sprintf("%d", input.B),
})
})
}

func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputStruct, error) {
func encodingStepStruct(_ context.Context, input StepInputStruct) (StepOutputStruct, error) {
return StepOutputStruct{
A: input,
B: "processed by encodingStepStruct",
Expand Down
48 changes: 19 additions & 29 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,8 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o
/******* STEP FUNCTIONS *******/
/******************************/

type StepFunc func(ctx context.Context, input any) (any, error)
type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error)
type StepFunc func(ctx context.Context) (any, error)
type GenericStepFunc[R any] func(ctx context.Context) (R, error)

const StepParamsKey DBOSContextKey = "stepParams"

Expand Down Expand Up @@ -729,7 +729,7 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams {

var typeErasedStepNameToStepName = make(map[string]string)

func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) {
func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
if ctx == nil {
return *new(R), newStepExecutionError("", "", "ctx cannot be nil")
}
Expand All @@ -738,46 +738,36 @@ func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P)
return *new(R), newStepExecutionError("", "", "step function cannot be nil")
}

// Type-erase the function based on its actual type
typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) {
typedInput, ok := input.(P)
if !ok {
return nil, newStepExecutionError("", "", fmt.Sprintf("unexpected input type: expected %T, got %T", *new(P), input))
}
return fn(ctx, typedInput)
})
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()

typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// Type-erase the function
typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) })
typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = stepName

// Call the executor method
result, err := ctx.RunAsStep(ctx, typeErasedFn, input)
if err != nil {
// In case the errors comes from the DBOS step logic, the result will be nil and we must handle it
if result == nil {
return *new(R), err
}
return result.(R), err
// Call the executor method and pass through the result/error
result, err := ctx.RunAsStep(ctx, typeErasedFn)
// Step function could return a nil result
if result == nil {
return *new(R), err
}

// Type-check and cast the result
// Otherwise type-check and cast the result
typedResult, ok := result.(R)
if !ok {
return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result)
}

return typedResult, nil
return typedResult, err
}

func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) {
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
// Get workflow state from context
wfState, ok := c.Value(workflowStateKey).(*workflowState)
if !ok || wfState == nil {
// TODO: try to print step name
return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?")
}

// This should not happen when called from the package-level RunAsStep
if fn == nil {
// TODO: try to print step name
return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil")
}

Expand All @@ -790,7 +780,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err

// If within a step, just run the function directly
if wfState.isWithinStep {
return fn(c, input)
return fn(c)
}

// Setup step state
Expand Down Expand Up @@ -819,7 +809,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
// Spawn a child DBOSContext with the step state
stepCtx := WithValue(c, workflowStateKey, &stepState)

stepOutput, stepError := fn(stepCtx, input)
stepOutput, stepError := fn(stepCtx)

// Retry if MaxRetries > 0 and the first execution failed
var joinedErrors error
Expand All @@ -845,7 +835,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
}

// Execute the retry
stepOutput, stepError = fn(stepCtx, input)
stepOutput, stepError = fn(stepCtx)

// If successful, break
if stepError == nil {
Expand Down
80 changes: 55 additions & 25 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,23 @@ func simpleWorkflowError(dbosCtx DBOSContext, input string) (int, error) {
}

func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, simpleStep, input)
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
}

func simpleStep(ctx context.Context, input string) (string, error) {
func simpleStep(_ context.Context) (string, error) {
return "from step", nil
}

func simpleStepError(ctx context.Context, input string) (string, error) {
func simpleStepError(_ context.Context) (string, error) {
return "", fmt.Errorf("step failure")
}

func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, simpleStepError, input)
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
return simpleStepError(ctx)
})
}

// idempotencyWorkflow increments a global counter and returns the input
Expand Down Expand Up @@ -292,38 +296,44 @@ func TestWorkflowsRegistration(t *testing.T) {
}
}

func stepWithinAStep(ctx context.Context, input string) (string, error) {
return simpleStep(ctx, input)
func stepWithinAStep(ctx context.Context) (string, error) {
return simpleStep(ctx)
}

func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, stepWithinAStep, input)
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
return stepWithinAStep(ctx)
})
}

// Global counter for retry testing
var stepRetryAttemptCount int

func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) {
func stepRetryAlwaysFailsStep(ctx context.Context) (string, error) {
stepRetryAttemptCount++
return "", fmt.Errorf("always fails - attempt %d", stepRetryAttemptCount)
}

var stepIdempotencyCounter int

func stepIdempotencyTest(ctx context.Context, input int) (string, error) {
func stepIdempotencyTest(ctx context.Context) (string, error) {
stepIdempotencyCounter++
return "", nil
}

func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) {
RunAsStep(dbosCtx, stepIdempotencyTest, 1)
RunAsStep[int](dbosCtx, func(ctx context.Context) (string, error) {
return stepIdempotencyTest(ctx)
})
stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{
MaxRetries: 5,
BaseInterval: 1 * time.Millisecond,
MaxInterval: 10 * time.Millisecond,
})

return RunAsStep(stepCtx, stepRetryAlwaysFailsStep, input)
return RunAsStep[string](stepCtx, func(ctx context.Context) (string, error) {
return stepRetryAlwaysFailsStep(ctx)
})
}

func TestSteps(t *testing.T) {
Expand All @@ -335,7 +345,9 @@ func TestSteps(t *testing.T) {

t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) {
// Attempt to run a step outside of a workflow context
_, err := RunAsStep(dbosCtx, simpleStep, "test")
_, err := RunAsStep[int](dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
if err == nil {
t.Fatal("expected error when running step outside of workflow context, but got none")
}
Expand Down Expand Up @@ -411,7 +423,7 @@ func TestSteps(t *testing.T) {
}

// Verify the error contains the step name and max retries
expectedErrorMessage := "dbos.stepRetryAlwaysFailsStep has exceeded its maximum of 5 retries"
expectedErrorMessage := "has exceeded its maximum of 5 retries"
if !strings.Contains(dbosErr.Message, expectedErrorMessage) {
t.Fatalf("expected error message to contain '%s', got '%s'", expectedErrorMessage, dbosErr.Message)
}
Expand Down Expand Up @@ -470,7 +482,9 @@ func TestChildWorkflow(t *testing.T) {
return "", fmt.Errorf("expected childWf workflow ID to be %s, got %s", expectedCurrentID, workflowID)
}
// Steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow
return RunAsStep(dbosCtx, simpleStep, "")
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
}
RegisterWorkflow(dbosCtx, childWf)

Expand Down Expand Up @@ -644,7 +658,9 @@ func TestChildWorkflow(t *testing.T) {
customChildID := uuid.NewString()

simpleChildWf := func(dbosCtx DBOSContext, input string) (string, error) {
return RunAsStep(dbosCtx, simpleStep, input)
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
return simpleStep(ctx)
})
}
RegisterWorkflow(dbosCtx, simpleChildWf)

Expand Down Expand Up @@ -713,23 +729,29 @@ func TestChildWorkflow(t *testing.T) {
// Idempotency workflows moved to test functions

func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) {
RunAsStep(dbosCtx, incrementCounter, int64(1))
RunAsStep[int64](dbosCtx, func(ctx context.Context) (int64, error) {
return incrementCounter(ctx, int64(1))
})
return input, nil
}

var blockingStepStopEvent *Event

func blockingStep(ctx context.Context, input string) (string, error) {
func blockingStep(_ context.Context) (string, error) {
blockingStepStopEvent.Wait()
return "", nil
}

var idempotencyWorkflowWithStepEvent *Event

func idempotencyWorkflowWithStep(dbosCtx DBOSContext, input string) (int64, error) {
RunAsStep(dbosCtx, incrementCounter, int64(1))
RunAsStep[int64](dbosCtx, func(ctx context.Context) (int64, error) {
return incrementCounter(ctx, int64(1))
})
idempotencyWorkflowWithStepEvent.Set()
RunAsStep(dbosCtx, blockingStep, input)
RunAsStep[int](dbosCtx, func(ctx context.Context) (string, error) {
return blockingStep(ctx)
})
return idempotencyCounter, nil
}

Expand Down Expand Up @@ -1131,9 +1153,9 @@ func TestScheduledWorkflows(t *testing.T) {
}

// Stop the workflowScheduler and check if it stops executing
currentCounter := counter
dbosCtx.(*dbosContext).getWorkflowScheduler().Stop()
time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions
currentCounter := counter // If more scheduled executions happen, this can also trigger a data race. If the scheduler is correct, there should be no race.
if counter >= currentCounter+2 {
t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2)
}
Expand Down Expand Up @@ -1253,7 +1275,9 @@ func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, er
}

func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) {
return RunAsStep(ctx, stepThatCallsSend, input)
return RunAsStep[sendWorkflowInput](ctx, func(context context.Context) (string, error) {
return stepThatCallsSend(context, input)
})
}

type sendRecvType struct {
Expand Down Expand Up @@ -2193,7 +2217,7 @@ func TestWorkflowTimeout(t *testing.T) {
}
})

waitForCancelStep := func(ctx context.Context, _ string) (string, error) {
waitForCancelStep := func(ctx context.Context) (string, error) {
// This step will trigger cancellation of the entire workflow context
<-ctx.Done()
if !errors.Is(ctx.Err(), context.Canceled) && !errors.Is(ctx.Err(), context.DeadlineExceeded) {
Expand All @@ -2203,7 +2227,9 @@ func TestWorkflowTimeout(t *testing.T) {
}

waitForCancelWorkflowWithStep := func(ctx DBOSContext, _ string) (string, error) {
return RunAsStep(ctx, waitForCancelStep, "trigger-cancellation")
return RunAsStep[sendWorkflowInput](ctx, func(context context.Context) (string, error) {
return waitForCancelStep(context)
})
}
RegisterWorkflow(dbosCtx, waitForCancelWorkflowWithStep)

Expand Down Expand Up @@ -2240,7 +2266,9 @@ func TestWorkflowTimeout(t *testing.T) {
// The timeout will trigger a step error, the workflow can do whatever it wants with that error
stepCtx, stepCancelFunc := WithTimeout(ctx, 1*time.Millisecond)
defer stepCancelFunc() // Ensure we clean up the context
_, err := RunAsStep(stepCtx, waitForCancelStep, "short-step-timeout")
_, err := RunAsStep[string](stepCtx, func(context context.Context) (string, error) {
return waitForCancelStep(context)
})
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected step to timeout, got: %v", err)
}
Expand Down Expand Up @@ -2287,7 +2315,9 @@ func TestWorkflowTimeout(t *testing.T) {
// This workflow will run a step that is not cancelable.
// What this means is the workflow *will* be cancelled, but the step will run normally
stepCtx := WithoutCancel(ctx)
res, err := RunAsStep(stepCtx, detachedStep, timeout*2)
res, err := RunAsStep[time.Duration](stepCtx, func(context context.Context) (string, error) {
return detachedStep(context, timeout*2)
})
if err != nil {
t.Fatalf("failed to run detached step: %v", err)
}
Expand Down
Loading