Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type DBOSContext interface {
Cancel() // Gracefully shutdown the DBOS runtime, waiting for workflows to complete and cleaning up resources

// Workflow operations
RunAsStep(_ DBOSContext, fn StepFunc) (any, error) // Execute a function as a durable step within a workflow
RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) // Execute a function as a durable step within a workflow
RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) // Start a new workflow execution
Send(_ DBOSContext, input WorkflowSendInput) error // Send a message to another workflow
Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) // Receive a message sent to this workflow
Expand Down
162 changes: 91 additions & 71 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"math"
"reflect"
"runtime"
"sync"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -827,70 +826,92 @@ type StepFunc func(ctx context.Context) (any, error)
// GenericStepFunc represents a type-safe step function with a specific output type R.
type GenericStepFunc[R any] func(ctx context.Context) (R, error)

// StepParamsKey is the context key for setting StepParams in a workflow context.
// Use this key with the dbos.WithValue to configure steps.
const StepParamsKey DBOSContextKey = "stepParams"

// StepParams configures retry behavior and identification for step execution.
// These parameters can be set in the context using the StepParamsKey.
type StepParams struct {
// stepOptions holds the configuration for step execution using functional options pattern.
type stepOptions struct {
MaxRetries int // Maximum number of retry attempts (0 = no retries)
BackoffFactor float64 // Exponential backoff multiplier between retries (default: 2.0)
BaseInterval time.Duration // Initial delay between retries (default: 100ms)
MaxInterval time.Duration // Maximum delay between retries (default: 5s)
StepName string // Custom name for the step (defaults to function name)
}

// setStepParamDefaults returns a StepParams struct with all defaults properly set
func setStepParamDefaults(params *StepParams, stepName string) *StepParams {
if params == nil {
return &StepParams{
MaxRetries: 0, // Default to no retries
BackoffFactor: _DEFAULT_STEP_BACKOFF_FACTOR,
BaseInterval: _DEFAULT_STEP_BASE_INTERVAL, // Default base interval
MaxInterval: _DEFAULT_STEP_MAX_INTERVAL, // Default max interval
StepName: func() string {
if value, ok := typeErasedStepNameToStepName.Load(stepName); ok {
return value.(string)
}
return "" // This should never happen
}(),
}
}

// Set defaults for zero values
if params.BackoffFactor == 0 {
params.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR // Default backoff factor
// setDefaults applies default values to stepOptions
func (opts *stepOptions) setDefaults() {
if opts.BackoffFactor == 0 {
opts.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR
}
if params.BaseInterval == 0 {
params.BaseInterval = _DEFAULT_STEP_BASE_INTERVAL // Default base interval
if opts.BaseInterval == 0 {
opts.BaseInterval = _DEFAULT_STEP_BASE_INTERVAL
}
if params.MaxInterval == 0 {
params.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL // Default max interval
if opts.MaxInterval == 0 {
opts.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL
}
if len(params.StepName) == 0 {
// If the step name is not provided, use the function name
if value, ok := typeErasedStepNameToStepName.Load(stepName); ok {
params.StepName = value.(string)
}

// StepOption is a functional option for configuring step execution parameters.
type StepOption func(*stepOptions)

// WithStepName sets a custom name for the step. If the step name has already been set
// by a previous call to WithStepName, this option will be ignored to allow
// multiple WithStepName calls without overriding the first one.
func WithStepName(name string) StepOption {
return func(opts *stepOptions) {
if opts.StepName == "" {
opts.StepName = name
}
}
}

return params
// WithStepMaxRetries sets the maximum number of retry attempts for the step.
// A value of 0 means no retries (default behavior).
func WithStepMaxRetries(maxRetries int) StepOption {
return func(opts *stepOptions) {
opts.MaxRetries = maxRetries
}
}

var typeErasedStepNameToStepName sync.Map
// WithBackoffFactor sets the exponential backoff multiplier between retries.
// The delay between retries is calculated as: BaseInterval * (BackoffFactor^(retry-1))
// Default value is 2.0.
func WithBackoffFactor(factor float64) StepOption {
return func(opts *stepOptions) {
opts.BackoffFactor = factor
}
}

// WithBaseInterval sets the initial delay between retries.
// Default value is 100ms.
func WithBaseInterval(interval time.Duration) StepOption {
return func(opts *stepOptions) {
opts.BaseInterval = interval
}
}

// WithMaxInterval sets the maximum delay between retries.
// Default value is 5s.
func WithMaxInterval(interval time.Duration) StepOption {
return func(opts *stepOptions) {
opts.MaxInterval = interval
}
}

// RunAsStep executes a function as a durable step within a workflow.
// Steps provide at-least-once execution guarantees and automatic retry capabilities.
// If a step has already been executed (e.g., during workflow recovery), its recorded
// result is returned instead of re-executing the function.
//
// Steps can be configured with retry parameters by setting StepParams in the context:
// Steps can be configured with functional options:
//
// stepCtx = context.WithValue(ctx, dbos.StepParamsKey, &dbos.StepParams{
// MaxRetries: 3,
// BaseInterval: 500 * time.Millisecond,
// })
// data, err := dbos.RunAsStep(ctx, func(ctx context.Context) ([]byte, error) {
// return MyStep(ctx, "https://api.example.com/data")
// }, dbos.WithStepMaxRetries(3), dbos.WithBaseInterval(500*time.Millisecond))
//
// Available options:
// - WithStepName: Custom name for the step (only sets if not already set)
// - WithStepMaxRetries: Maximum retry attempts (default: 0)
// - WithBackoffFactor: Exponential backoff multiplier (default: 2.0)
// - WithBaseInterval: Initial delay between retries (default: 100ms)
// - WithMaxInterval: Maximum delay between retries (default: 5s)
//
// Example:
//
Expand All @@ -904,17 +925,17 @@ var typeErasedStepNameToStepName sync.Map
// }
//
// // Within a workflow:
// data, err := dbos.RunAsStep(stepCtx, func(ctx context.Context) ([]byte, error) {
// data, err := dbos.RunAsStep(ctx, func(ctx context.Context) ([]byte, error) {
// return MyStep(ctx, "https://api.example.com/data")
// })
// }, dbos.WithStepName("FetchData"), dbos.WithStepMaxRetries(3))
// if err != nil {
// return nil, err
// }
//
// Note that the function passed to RunAsStep must accept a context.Context as its first parameter
// and this context *must* be the one specified in the function's signature (not the context passed to RunAsStep).
// Under the hood, DBOS will augment the step's context and pass it to the function when executing it durably.
func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R], opts ...StepOption) (R, error) {
if ctx == nil {
return *new(R), newStepExecutionError("", "", "ctx cannot be nil")
}
Expand All @@ -923,15 +944,14 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
return *new(R), newStepExecutionError("", "", "step function cannot be nil")
}

// Append WithStepName option to ensure the step name is set. This will not erase a user-provided step name
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
opts = append(opts, WithStepName(stepName))

// Type-erase the function
typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) })
typeErasedFnName := runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()
typeErasedStepNameToStepName.LoadOrStore(typeErasedFnName, stepName)

// Call the executor method and pass through the result/error
result, err := ctx.RunAsStep(ctx, typeErasedFn)
result, err := ctx.RunAsStep(ctx, typeErasedFn, opts...)
// Step function could return a nil result
if result == nil {
return *new(R), err
Expand All @@ -944,23 +964,23 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
return typedResult, err
}

func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
// Look up for step parameters in the context and set defaults
params, ok := c.Value(StepParamsKey).(*StepParams)
if !ok {
params = nil
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) {
// Process functional options
stepOpts := &stepOptions{}
for _, opt := range opts {
opt(stepOpts)
}
params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name())
stepOpts.setDefaults()

// Get workflow state from context
wfState, ok := c.Value(workflowStateKey).(*workflowState)
if !ok || wfState == nil {
return nil, newStepExecutionError("", params.StepName, "workflow state not found in context: are you running this step within a workflow?")
return nil, newStepExecutionError("", stepOpts.StepName, "workflow state not found in context: are you running this step within a workflow?")
}

// This should not happen when called from the package-level RunAsStep
if fn == nil {
return nil, newStepExecutionError(wfState.workflowID, params.StepName, "step function cannot be nil")
return nil, newStepExecutionError(wfState.workflowID, stepOpts.StepName, "step function cannot be nil")
}

// If within a step, just run the function directly
Expand All @@ -982,10 +1002,10 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
recordedOutput, err := c.systemDB.checkOperationExecution(uncancellableCtx, checkOperationExecutionDBInput{
workflowID: stepState.workflowID,
stepID: stepState.stepID,
stepName: params.StepName,
stepName: stepOpts.StepName,
})
if err != nil {
return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("checking operation execution: %v", err))
return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("checking operation execution: %v", err))
}
if recordedOutput != nil {
return recordedOutput.output, recordedOutput.err
Expand All @@ -998,23 +1018,23 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {

// Retry if MaxRetries > 0 and the first execution failed
var joinedErrors error
if stepError != nil && params.MaxRetries > 0 {
if stepError != nil && stepOpts.MaxRetries > 0 {
joinedErrors = errors.Join(joinedErrors, stepError)

for retry := 1; retry <= params.MaxRetries; retry++ {
for retry := 1; retry <= stepOpts.MaxRetries; retry++ {
// Calculate delay for exponential backoff
delay := params.BaseInterval
delay := stepOpts.BaseInterval
if retry > 1 {
exponentialDelay := float64(params.BaseInterval) * math.Pow(params.BackoffFactor, float64(retry-1))
delay = time.Duration(math.Min(exponentialDelay, float64(params.MaxInterval)))
exponentialDelay := float64(stepOpts.BaseInterval) * math.Pow(stepOpts.BackoffFactor, float64(retry-1))
delay = time.Duration(math.Min(exponentialDelay, float64(stepOpts.MaxInterval)))
}

c.logger.Error("step failed, retrying", "step_name", params.StepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError)
c.logger.Error("step failed, retrying", "step_name", stepOpts.StepName, "retry", retry, "max_retries", stepOpts.MaxRetries, "delay", delay, "error", stepError)

// Wait before retry
select {
case <-c.Done():
return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("context cancelled during retry: %v", c.Err()))
return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("context cancelled during retry: %v", c.Err()))
case <-time.After(delay):
// Continue to retry
}
Expand All @@ -1031,8 +1051,8 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
joinedErrors = errors.Join(joinedErrors, stepError)

// If max retries reached, create MaxStepRetriesExceeded error
if retry == params.MaxRetries {
stepError = newMaxStepRetriesExceededError(stepState.workflowID, params.StepName, params.MaxRetries, joinedErrors)
if retry == stepOpts.MaxRetries {
stepError = newMaxStepRetriesExceededError(stepState.workflowID, stepOpts.StepName, stepOpts.MaxRetries, joinedErrors)
break
}
}
Expand All @@ -1041,14 +1061,14 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
// Record the final result
dbInput := recordOperationResultDBInput{
workflowID: stepState.workflowID,
stepName: params.StepName,
stepName: stepOpts.StepName,
stepID: stepState.stepID,
err: stepError,
output: stepOutput,
}
recErr := c.systemDB.recordOperationResult(uncancellableCtx, dbInput)
if recErr != nil {
return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("recording step outcome: %v", recErr))
return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("recording step outcome: %v", recErr))
}

return stepOutput, stepError
Expand Down
Loading
Loading