Skip to content

Commit a680f8f

Browse files
authored
Functional options for steps (#59)
Our current method for capturing the reflection name of a step is flawed. The type-erased function (anonymous) always has the same name, and thus every call to RunAsStep overwrites it. This, of course, results in wrong data in the `operation_results` table. There is no good way to associate a unique key that the interface method can use on this map. Hence, let's just move to the cleaner functional options interface to handle step options, removing the need for the in-memory map. Added a test (`TestSteps/checkStepName`) that fails without this change.
1 parent 3aaf85a commit a680f8f

File tree

3 files changed

+211
-80
lines changed

3 files changed

+211
-80
lines changed

dbos/dbos.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ type DBOSContext interface {
7575
Cancel() // Gracefully shutdown the DBOS runtime, waiting for workflows to complete and cleaning up resources
7676

7777
// Workflow operations
78-
RunAsStep(_ DBOSContext, fn StepFunc) (any, error) // Execute a function as a durable step within a workflow
78+
RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) // Execute a function as a durable step within a workflow
7979
RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) // Start a new workflow execution
8080
Send(_ DBOSContext, input WorkflowSendInput) error // Send a message to another workflow
8181
Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) // Receive a message sent to this workflow

dbos/workflow.go

Lines changed: 96 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"math"
99
"reflect"
1010
"runtime"
11-
"sync"
1211
"time"
1312

1413
"github.com/google/uuid"
@@ -247,8 +246,12 @@ func registerWorkflow(ctx DBOSContext, workflowFQN string, fn WrappedWorkflowFun
247246
}
248247

249248
// We need to get a mapping from custom name to FQN for registry lookups that might not know the FQN (queue, recovery)
249+
// We also panic if we found the name was already registered (this could happen if registering two different workflows under the same custom name)
250250
if len(customName) > 0 {
251-
c.workflowCustomNametoFQN.Store(customName, workflowFQN)
251+
if _, exists := c.workflowCustomNametoFQN.LoadOrStore(customName, workflowFQN); exists {
252+
c.logger.Error("workflow function already registered", "custom_name", customName)
253+
panic(newConflictingRegistrationError(customName))
254+
}
252255
} else {
253256
c.workflowCustomNametoFQN.Store(workflowFQN, workflowFQN) // Store the FQN as the custom name if none was provided
254257
}
@@ -827,70 +830,92 @@ type StepFunc func(ctx context.Context) (any, error)
827830
// GenericStepFunc represents a type-safe step function with a specific output type R.
828831
type GenericStepFunc[R any] func(ctx context.Context) (R, error)
829832

830-
// StepParamsKey is the context key for setting StepParams in a workflow context.
831-
// Use this key with the dbos.WithValue to configure steps.
832-
const StepParamsKey DBOSContextKey = "stepParams"
833-
834-
// StepParams configures retry behavior and identification for step execution.
835-
// These parameters can be set in the context using the StepParamsKey.
836-
type StepParams struct {
833+
// stepOptions holds the configuration for step execution using functional options pattern.
834+
type stepOptions struct {
837835
MaxRetries int // Maximum number of retry attempts (0 = no retries)
838836
BackoffFactor float64 // Exponential backoff multiplier between retries (default: 2.0)
839837
BaseInterval time.Duration // Initial delay between retries (default: 100ms)
840838
MaxInterval time.Duration // Maximum delay between retries (default: 5s)
841839
StepName string // Custom name for the step (defaults to function name)
842840
}
843841

844-
// setStepParamDefaults returns a StepParams struct with all defaults properly set
845-
func setStepParamDefaults(params *StepParams, stepName string) *StepParams {
846-
if params == nil {
847-
return &StepParams{
848-
MaxRetries: 0, // Default to no retries
849-
BackoffFactor: _DEFAULT_STEP_BACKOFF_FACTOR,
850-
BaseInterval: _DEFAULT_STEP_BASE_INTERVAL, // Default base interval
851-
MaxInterval: _DEFAULT_STEP_MAX_INTERVAL, // Default max interval
852-
StepName: func() string {
853-
if value, ok := typeErasedStepNameToStepName.Load(stepName); ok {
854-
return value.(string)
855-
}
856-
return "" // This should never happen
857-
}(),
858-
}
859-
}
860-
861-
// Set defaults for zero values
862-
if params.BackoffFactor == 0 {
863-
params.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR // Default backoff factor
842+
// setDefaults applies default values to stepOptions
843+
func (opts *stepOptions) setDefaults() {
844+
if opts.BackoffFactor == 0 {
845+
opts.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR
864846
}
865-
if params.BaseInterval == 0 {
866-
params.BaseInterval = _DEFAULT_STEP_BASE_INTERVAL // Default base interval
847+
if opts.BaseInterval == 0 {
848+
opts.BaseInterval = _DEFAULT_STEP_BASE_INTERVAL
867849
}
868-
if params.MaxInterval == 0 {
869-
params.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL // Default max interval
850+
if opts.MaxInterval == 0 {
851+
opts.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL
870852
}
871-
if len(params.StepName) == 0 {
872-
// If the step name is not provided, use the function name
873-
if value, ok := typeErasedStepNameToStepName.Load(stepName); ok {
874-
params.StepName = value.(string)
853+
}
854+
855+
// StepOption is a functional option for configuring step execution parameters.
856+
type StepOption func(*stepOptions)
857+
858+
// WithStepName sets a custom name for the step. If the step name has already been set
859+
// by a previous call to WithStepName, this option will be ignored to allow
860+
// multiple WithStepName calls without overriding the first one.
861+
func WithStepName(name string) StepOption {
862+
return func(opts *stepOptions) {
863+
if opts.StepName == "" {
864+
opts.StepName = name
875865
}
876866
}
867+
}
868+
869+
// WithStepMaxRetries sets the maximum number of retry attempts for the step.
870+
// A value of 0 means no retries (default behavior).
871+
func WithStepMaxRetries(maxRetries int) StepOption {
872+
return func(opts *stepOptions) {
873+
opts.MaxRetries = maxRetries
874+
}
875+
}
876+
877+
// WithBackoffFactor sets the exponential backoff multiplier between retries.
878+
// The delay between retries is calculated as: BaseInterval * (BackoffFactor^(retry-1))
879+
// Default value is 2.0.
880+
func WithBackoffFactor(factor float64) StepOption {
881+
return func(opts *stepOptions) {
882+
opts.BackoffFactor = factor
883+
}
884+
}
877885

878-
return params
886+
// WithBaseInterval sets the initial delay between retries.
887+
// Default value is 100ms.
888+
func WithBaseInterval(interval time.Duration) StepOption {
889+
return func(opts *stepOptions) {
890+
opts.BaseInterval = interval
891+
}
879892
}
880893

881-
var typeErasedStepNameToStepName sync.Map
894+
// WithMaxInterval sets the maximum delay between retries.
895+
// Default value is 5s.
896+
func WithMaxInterval(interval time.Duration) StepOption {
897+
return func(opts *stepOptions) {
898+
opts.MaxInterval = interval
899+
}
900+
}
882901

883902
// RunAsStep executes a function as a durable step within a workflow.
884903
// Steps provide at-least-once execution guarantees and automatic retry capabilities.
885904
// If a step has already been executed (e.g., during workflow recovery), its recorded
886905
// result is returned instead of re-executing the function.
887906
//
888-
// Steps can be configured with retry parameters by setting StepParams in the context:
907+
// Steps can be configured with functional options:
889908
//
890-
// stepCtx = context.WithValue(ctx, dbos.StepParamsKey, &dbos.StepParams{
891-
// MaxRetries: 3,
892-
// BaseInterval: 500 * time.Millisecond,
893-
// })
909+
// data, err := dbos.RunAsStep(ctx, func(ctx context.Context) ([]byte, error) {
910+
// return MyStep(ctx, "https://api.example.com/data")
911+
// }, dbos.WithStepMaxRetries(3), dbos.WithBaseInterval(500*time.Millisecond))
912+
//
913+
// Available options:
914+
// - WithStepName: Custom name for the step (only sets if not already set)
915+
// - WithStepMaxRetries: Maximum retry attempts (default: 0)
916+
// - WithBackoffFactor: Exponential backoff multiplier (default: 2.0)
917+
// - WithBaseInterval: Initial delay between retries (default: 100ms)
918+
// - WithMaxInterval: Maximum delay between retries (default: 5s)
894919
//
895920
// Example:
896921
//
@@ -904,17 +929,17 @@ var typeErasedStepNameToStepName sync.Map
904929
// }
905930
//
906931
// // Within a workflow:
907-
// data, err := dbos.RunAsStep(stepCtx, func(ctx context.Context) ([]byte, error) {
932+
// data, err := dbos.RunAsStep(ctx, func(ctx context.Context) ([]byte, error) {
908933
// return MyStep(ctx, "https://api.example.com/data")
909-
// })
934+
// }, dbos.WithStepName("FetchData"), dbos.WithStepMaxRetries(3))
910935
// if err != nil {
911936
// return nil, err
912937
// }
913938
//
914939
// Note that the function passed to RunAsStep must accept a context.Context as its first parameter
915940
// and this context *must* be the one specified in the function's signature (not the context passed to RunAsStep).
916941
// Under the hood, DBOS will augment the step's context and pass it to the function when executing it durably.
917-
func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
942+
func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R], opts ...StepOption) (R, error) {
918943
if ctx == nil {
919944
return *new(R), newStepExecutionError("", "", "ctx cannot be nil")
920945
}
@@ -923,15 +948,14 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
923948
return *new(R), newStepExecutionError("", "", "step function cannot be nil")
924949
}
925950

951+
// Append WithStepName option to ensure the step name is set. This will not erase a user-provided step name
926952
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
953+
opts = append(opts, WithStepName(stepName))
927954

928955
// Type-erase the function
929956
typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) })
930-
typeErasedFnName := runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()
931-
typeErasedStepNameToStepName.LoadOrStore(typeErasedFnName, stepName)
932957

933-
// Call the executor method and pass through the result/error
934-
result, err := ctx.RunAsStep(ctx, typeErasedFn)
958+
result, err := ctx.RunAsStep(ctx, typeErasedFn, opts...)
935959
// Step function could return a nil result
936960
if result == nil {
937961
return *new(R), err
@@ -944,23 +968,23 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
944968
return typedResult, err
945969
}
946970

947-
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
948-
// Look up for step parameters in the context and set defaults
949-
params, ok := c.Value(StepParamsKey).(*StepParams)
950-
if !ok {
951-
params = nil
971+
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) {
972+
// Process functional options
973+
stepOpts := &stepOptions{}
974+
for _, opt := range opts {
975+
opt(stepOpts)
952976
}
953-
params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name())
977+
stepOpts.setDefaults()
954978

955979
// Get workflow state from context
956980
wfState, ok := c.Value(workflowStateKey).(*workflowState)
957981
if !ok || wfState == nil {
958-
return nil, newStepExecutionError("", params.StepName, "workflow state not found in context: are you running this step within a workflow?")
982+
return nil, newStepExecutionError("", stepOpts.StepName, "workflow state not found in context: are you running this step within a workflow?")
959983
}
960984

961985
// This should not happen when called from the package-level RunAsStep
962986
if fn == nil {
963-
return nil, newStepExecutionError(wfState.workflowID, params.StepName, "step function cannot be nil")
987+
return nil, newStepExecutionError(wfState.workflowID, stepOpts.StepName, "step function cannot be nil")
964988
}
965989

966990
// If within a step, just run the function directly
@@ -982,10 +1006,10 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
9821006
recordedOutput, err := c.systemDB.checkOperationExecution(uncancellableCtx, checkOperationExecutionDBInput{
9831007
workflowID: stepState.workflowID,
9841008
stepID: stepState.stepID,
985-
stepName: params.StepName,
1009+
stepName: stepOpts.StepName,
9861010
})
9871011
if err != nil {
988-
return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("checking operation execution: %v", err))
1012+
return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("checking operation execution: %v", err))
9891013
}
9901014
if recordedOutput != nil {
9911015
return recordedOutput.output, recordedOutput.err
@@ -998,23 +1022,23 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
9981022

9991023
// Retry if MaxRetries > 0 and the first execution failed
10001024
var joinedErrors error
1001-
if stepError != nil && params.MaxRetries > 0 {
1025+
if stepError != nil && stepOpts.MaxRetries > 0 {
10021026
joinedErrors = errors.Join(joinedErrors, stepError)
10031027

1004-
for retry := 1; retry <= params.MaxRetries; retry++ {
1028+
for retry := 1; retry <= stepOpts.MaxRetries; retry++ {
10051029
// Calculate delay for exponential backoff
1006-
delay := params.BaseInterval
1030+
delay := stepOpts.BaseInterval
10071031
if retry > 1 {
1008-
exponentialDelay := float64(params.BaseInterval) * math.Pow(params.BackoffFactor, float64(retry-1))
1009-
delay = time.Duration(math.Min(exponentialDelay, float64(params.MaxInterval)))
1032+
exponentialDelay := float64(stepOpts.BaseInterval) * math.Pow(stepOpts.BackoffFactor, float64(retry-1))
1033+
delay = time.Duration(math.Min(exponentialDelay, float64(stepOpts.MaxInterval)))
10101034
}
10111035

1012-
c.logger.Error("step failed, retrying", "step_name", params.StepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError)
1036+
c.logger.Error("step failed, retrying", "step_name", stepOpts.StepName, "retry", retry, "max_retries", stepOpts.MaxRetries, "delay", delay, "error", stepError)
10131037

10141038
// Wait before retry
10151039
select {
10161040
case <-c.Done():
1017-
return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("context cancelled during retry: %v", c.Err()))
1041+
return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("context cancelled during retry: %v", c.Err()))
10181042
case <-time.After(delay):
10191043
// Continue to retry
10201044
}
@@ -1031,8 +1055,8 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
10311055
joinedErrors = errors.Join(joinedErrors, stepError)
10321056

10331057
// If max retries reached, create MaxStepRetriesExceeded error
1034-
if retry == params.MaxRetries {
1035-
stepError = newMaxStepRetriesExceededError(stepState.workflowID, params.StepName, params.MaxRetries, joinedErrors)
1058+
if retry == stepOpts.MaxRetries {
1059+
stepError = newMaxStepRetriesExceededError(stepState.workflowID, stepOpts.StepName, stepOpts.MaxRetries, joinedErrors)
10361060
break
10371061
}
10381062
}
@@ -1041,14 +1065,14 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
10411065
// Record the final result
10421066
dbInput := recordOperationResultDBInput{
10431067
workflowID: stepState.workflowID,
1044-
stepName: params.StepName,
1068+
stepName: stepOpts.StepName,
10451069
stepID: stepState.stepID,
10461070
err: stepError,
10471071
output: stepOutput,
10481072
}
10491073
recErr := c.systemDB.recordOperationResult(uncancellableCtx, dbInput)
10501074
if recErr != nil {
1051-
return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("recording step outcome: %v", recErr))
1075+
return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("recording step outcome: %v", recErr))
10521076
}
10531077

10541078
return stepOutput, stepError

0 commit comments

Comments
 (0)