diff --git a/dbos/dbos.go b/dbos/dbos.go index a7838177..1dce00a9 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -75,7 +75,7 @@ type DBOSContext interface { Cancel() // Gracefully shutdown the DBOS runtime, waiting for workflows to complete and cleaning up resources // Workflow operations - RunAsStep(_ DBOSContext, fn StepFunc) (any, error) // Execute a function as a durable step within a workflow + RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) // Execute a function as a durable step within a workflow RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) // Start a new workflow execution Send(_ DBOSContext, input WorkflowSendInput) error // Send a message to another workflow Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) // Receive a message sent to this workflow diff --git a/dbos/workflow.go b/dbos/workflow.go index b1657541..39315308 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -8,7 +8,6 @@ import ( "math" "reflect" "runtime" - "sync" "time" "github.com/google/uuid" @@ -247,8 +246,12 @@ func registerWorkflow(ctx DBOSContext, workflowFQN string, fn WrappedWorkflowFun } // We need to get a mapping from custom name to FQN for registry lookups that might not know the FQN (queue, recovery) + // We also panic if we found the name was already registered (this could happen if registering two different workflows under the same custom name) if len(customName) > 0 { - c.workflowCustomNametoFQN.Store(customName, workflowFQN) + if _, exists := c.workflowCustomNametoFQN.LoadOrStore(customName, workflowFQN); exists { + c.logger.Error("workflow function already registered", "custom_name", customName) + panic(newConflictingRegistrationError(customName)) + } } else { c.workflowCustomNametoFQN.Store(workflowFQN, workflowFQN) // Store the FQN as the custom name if none was provided } @@ -827,13 +830,8 @@ type StepFunc func(ctx context.Context) (any, error) // GenericStepFunc 
represents a type-safe step function with a specific output type R. type GenericStepFunc[R any] func(ctx context.Context) (R, error) -// StepParamsKey is the context key for setting StepParams in a workflow context. -// Use this key with the dbos.WithValue to configure steps. -const StepParamsKey DBOSContextKey = "stepParams" - -// StepParams configures retry behavior and identification for step execution. -// These parameters can be set in the context using the StepParamsKey. -type StepParams struct { +// stepOptions holds the configuration for step execution using functional options pattern. +type stepOptions struct { MaxRetries int // Maximum number of retry attempts (0 = no retries) BackoffFactor float64 // Exponential backoff multiplier between retries (default: 2.0) BaseInterval time.Duration // Initial delay between retries (default: 100ms) @@ -841,56 +839,83 @@ type StepParams struct { StepName string // Custom name for the step (defaults to function name) } -// setStepParamDefaults returns a StepParams struct with all defaults properly set -func setStepParamDefaults(params *StepParams, stepName string) *StepParams { - if params == nil { - return &StepParams{ - MaxRetries: 0, // Default to no retries - BackoffFactor: _DEFAULT_STEP_BACKOFF_FACTOR, - BaseInterval: _DEFAULT_STEP_BASE_INTERVAL, // Default base interval - MaxInterval: _DEFAULT_STEP_MAX_INTERVAL, // Default max interval - StepName: func() string { - if value, ok := typeErasedStepNameToStepName.Load(stepName); ok { - return value.(string) - } - return "" // This should never happen - }(), - } - } - - // Set defaults for zero values - if params.BackoffFactor == 0 { - params.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR // Default backoff factor +// setDefaults applies default values to stepOptions +func (opts *stepOptions) setDefaults() { + if opts.BackoffFactor == 0 { + opts.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR } - if params.BaseInterval == 0 { - params.BaseInterval = 
_DEFAULT_STEP_BASE_INTERVAL // Default base interval + if opts.BaseInterval == 0 { + opts.BaseInterval = _DEFAULT_STEP_BASE_INTERVAL } - if params.MaxInterval == 0 { - params.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL // Default max interval + if opts.MaxInterval == 0 { + opts.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL } - if len(params.StepName) == 0 { - // If the step name is not provided, use the function name - if value, ok := typeErasedStepNameToStepName.Load(stepName); ok { - params.StepName = value.(string) +} + +// StepOption is a functional option for configuring step execution parameters. +type StepOption func(*stepOptions) + +// WithStepName sets a custom name for the step. If the step name has already been set +// by a previous call to WithStepName, this option will be ignored to allow +// multiple WithStepName calls without overriding the first one. +func WithStepName(name string) StepOption { + return func(opts *stepOptions) { + if opts.StepName == "" { + opts.StepName = name } } +} + +// WithStepMaxRetries sets the maximum number of retry attempts for the step. +// A value of 0 means no retries (default behavior). +func WithStepMaxRetries(maxRetries int) StepOption { + return func(opts *stepOptions) { + opts.MaxRetries = maxRetries + } +} + +// WithBackoffFactor sets the exponential backoff multiplier between retries. +// The delay between retries is calculated as: BaseInterval * (BackoffFactor^(retry-1)) +// Default value is 2.0. +func WithBackoffFactor(factor float64) StepOption { + return func(opts *stepOptions) { + opts.BackoffFactor = factor + } +} - return params +// WithBaseInterval sets the initial delay between retries. +// Default value is 100ms. +func WithBaseInterval(interval time.Duration) StepOption { + return func(opts *stepOptions) { + opts.BaseInterval = interval + } } -var typeErasedStepNameToStepName sync.Map +// WithMaxInterval sets the maximum delay between retries. +// Default value is 5s. 
+func WithMaxInterval(interval time.Duration) StepOption { + return func(opts *stepOptions) { + opts.MaxInterval = interval + } +} // RunAsStep executes a function as a durable step within a workflow. // Steps provide at-least-once execution guarantees and automatic retry capabilities. // If a step has already been executed (e.g., during workflow recovery), its recorded // result is returned instead of re-executing the function. // -// Steps can be configured with retry parameters by setting StepParams in the context: +// Steps can be configured with functional options: // -// stepCtx = context.WithValue(ctx, dbos.StepParamsKey, &dbos.StepParams{ -// MaxRetries: 3, -// BaseInterval: 500 * time.Millisecond, -// }) +// data, err := dbos.RunAsStep(ctx, func(ctx context.Context) ([]byte, error) { +// return MyStep(ctx, "https://api.example.com/data") +// }, dbos.WithStepMaxRetries(3), dbos.WithBaseInterval(500*time.Millisecond)) +// +// Available options: +// - WithStepName: Custom name for the step (only sets if not already set) +// - WithStepMaxRetries: Maximum retry attempts (default: 0) +// - WithBackoffFactor: Exponential backoff multiplier (default: 2.0) +// - WithBaseInterval: Initial delay between retries (default: 100ms) +// - WithMaxInterval: Maximum delay between retries (default: 5s) // // Example: // @@ -904,9 +929,9 @@ var typeErasedStepNameToStepName sync.Map // } // // // Within a workflow: -// data, err := dbos.RunAsStep(stepCtx, func(ctx context.Context) ([]byte, error) { +// data, err := dbos.RunAsStep(ctx, func(ctx context.Context) ([]byte, error) { // return MyStep(ctx, "https://api.example.com/data") -// }) +// }, dbos.WithStepName("FetchData"), dbos.WithStepMaxRetries(3)) // if err != nil { // return nil, err // } @@ -914,7 +939,7 @@ var typeErasedStepNameToStepName sync.Map // Note that the function passed to RunAsStep must accept a context.Context as its first parameter // and this context *must* be the one specified in the function's signature 
(not the context passed to RunAsStep). // Under the hood, DBOS will augment the step's context and pass it to the function when executing it durably. -func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) { +func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R], opts ...StepOption) (R, error) { if ctx == nil { return *new(R), newStepExecutionError("", "", "ctx cannot be nil") } @@ -923,15 +948,14 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) { return *new(R), newStepExecutionError("", "", "step function cannot be nil") } + // Append WithStepName option to ensure the step name is set. This will not erase a user-provided step name stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + opts = append(opts, WithStepName(stepName)) // Type-erase the function typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) }) - typeErasedFnName := runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name() - typeErasedStepNameToStepName.LoadOrStore(typeErasedFnName, stepName) - // Call the executor method and pass through the result/error - result, err := ctx.RunAsStep(ctx, typeErasedFn) + result, err := ctx.RunAsStep(ctx, typeErasedFn, opts...) 
// Step function could return a nil result if result == nil { return *new(R), err @@ -944,23 +968,23 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) { return typedResult, err } -func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { - // Look up for step parameters in the context and set defaults - params, ok := c.Value(StepParamsKey).(*StepParams) - if !ok { - params = nil +func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) { + // Process functional options + stepOpts := &stepOptions{} + for _, opt := range opts { + opt(stepOpts) } - params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) + stepOpts.setDefaults() // Get workflow state from context wfState, ok := c.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { - return nil, newStepExecutionError("", params.StepName, "workflow state not found in context: are you running this step within a workflow?") + return nil, newStepExecutionError("", stepOpts.StepName, "workflow state not found in context: are you running this step within a workflow?") } // This should not happen when called from the package-level RunAsStep if fn == nil { - return nil, newStepExecutionError(wfState.workflowID, params.StepName, "step function cannot be nil") + return nil, newStepExecutionError(wfState.workflowID, stepOpts.StepName, "step function cannot be nil") } // If within a step, just run the function directly @@ -982,10 +1006,10 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { recordedOutput, err := c.systemDB.checkOperationExecution(uncancellableCtx, checkOperationExecutionDBInput{ workflowID: stepState.workflowID, stepID: stepState.stepID, - stepName: params.StepName, + stepName: stepOpts.StepName, }) if err != nil { - return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("checking operation execution: %v", err)) + return nil, 
newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("checking operation execution: %v", err)) } if recordedOutput != nil { return recordedOutput.output, recordedOutput.err @@ -998,23 +1022,23 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { // Retry if MaxRetries > 0 and the first execution failed var joinedErrors error - if stepError != nil && params.MaxRetries > 0 { + if stepError != nil && stepOpts.MaxRetries > 0 { joinedErrors = errors.Join(joinedErrors, stepError) - for retry := 1; retry <= params.MaxRetries; retry++ { + for retry := 1; retry <= stepOpts.MaxRetries; retry++ { // Calculate delay for exponential backoff - delay := params.BaseInterval + delay := stepOpts.BaseInterval if retry > 1 { - exponentialDelay := float64(params.BaseInterval) * math.Pow(params.BackoffFactor, float64(retry-1)) - delay = time.Duration(math.Min(exponentialDelay, float64(params.MaxInterval))) + exponentialDelay := float64(stepOpts.BaseInterval) * math.Pow(stepOpts.BackoffFactor, float64(retry-1)) + delay = time.Duration(math.Min(exponentialDelay, float64(stepOpts.MaxInterval))) } - c.logger.Error("step failed, retrying", "step_name", params.StepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) + c.logger.Error("step failed, retrying", "step_name", stepOpts.StepName, "retry", retry, "max_retries", stepOpts.MaxRetries, "delay", delay, "error", stepError) // Wait before retry select { case <-c.Done(): - return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("context cancelled during retry: %v", c.Err())) + return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("context cancelled during retry: %v", c.Err())) case <-time.After(delay): // Continue to retry } @@ -1031,8 +1055,8 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { joinedErrors = errors.Join(joinedErrors, stepError) // If max retries reached, 
create MaxStepRetriesExceeded error - if retry == params.MaxRetries { - stepError = newMaxStepRetriesExceededError(stepState.workflowID, params.StepName, params.MaxRetries, joinedErrors) + if retry == stepOpts.MaxRetries { + stepError = newMaxStepRetriesExceededError(stepState.workflowID, stepOpts.StepName, stepOpts.MaxRetries, joinedErrors) break } } @@ -1041,14 +1065,14 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { // Record the final result dbInput := recordOperationResultDBInput{ workflowID: stepState.workflowID, - stepName: params.StepName, + stepName: stepOpts.StepName, stepID: stepState.stepID, err: stepError, output: stepOutput, } recErr := c.systemDB.recordOperationResult(uncancellableCtx, dbInput) if recErr != nil { - return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("recording step outcome: %v", recErr)) + return nil, newStepExecutionError(stepState.workflowID, stepOpts.StepName, fmt.Sprintf("recording step outcome: %v", recErr)) } return stepOutput, stepError diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 8457ab1e..1e82d75f 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -327,6 +327,24 @@ func TestWorkflowsRegistration(t *testing.T) { RegisterWorkflow(freshCtx, simpleWorkflow, WithWorkflowName("custom-workflow")) }) + t.Run("DifferentWorkflowsSameCustomName", func(t *testing.T) { + // Create a fresh DBOS context for this test + freshCtx := setupDBOS(t, false, false) // Don't check for leaks and don't reset DB + + // First registration with custom name should work + RegisterWorkflow(freshCtx, simpleWorkflow, WithWorkflowName("same-name")) + + // Second registration of different workflow with same custom name should panic with ConflictingRegistrationError + defer func() { + r := recover() + require.NotNil(t, r, "expected panic from registering different workflows with same custom name but got none") + dbosErr, ok := r.(*DBOSError) + require.True(t, 
ok, "expected panic to be *DBOSError, got %T", r) + assert.Equal(t, ConflictingRegistrationError, dbosErr.Code) + }() + RegisterWorkflow(freshCtx, simpleWorkflowError, WithWorkflowName("same-name")) + }) + t.Run("RegisterAfterLaunchPanics", func(t *testing.T) { // Create a fresh DBOS context for this test freshCtx := setupDBOS(t, false, false) // Don't check for leaks and don't reset DB @@ -375,15 +393,26 @@ func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) { RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { return stepIdempotencyTest(ctx) }) - stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{ - MaxRetries: 5, - BaseInterval: 1 * time.Millisecond, - MaxInterval: 10 * time.Millisecond, - }) - return RunAsStep(stepCtx, func(ctx context.Context) (string, error) { + return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { return stepRetryAlwaysFailsStep(ctx) - }) + }, WithStepMaxRetries(5), WithBaseInterval(1*time.Millisecond), WithMaxInterval(10*time.Millisecond)) +} + +func step1(_ context.Context) (string, error) { + return "", nil +} + +func testStepWf1(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, step1) +} + +func step2(_ context.Context) (string, error) { + return "", nil +} + +func testStepWf2(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, step2) } func TestSteps(t *testing.T) { @@ -392,6 +421,8 @@ func TestSteps(t *testing.T) { // Create workflows with executor RegisterWorkflow(dbosCtx, stepWithinAStepWorkflow) RegisterWorkflow(dbosCtx, stepRetryWorkflow) + RegisterWorkflow(dbosCtx, testStepWf1) + RegisterWorkflow(dbosCtx, testStepWf2) t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { // Attempt to run a step outside of a workflow context @@ -469,6 +500,82 @@ func TestSteps(t *testing.T) { // Verify the idempotency step was executed only once assert.Equal(t, 1, stepIdempotencyCounter, "expected idempotency step to be executed 
only once") }) + + t.Run("checkStepName", func(t *testing.T) { + // Run first workflow, whose step uses the default (function-derived) step name + handle1, err := RunAsWorkflow(dbosCtx, testStepWf1, "test-input-1") + require.NoError(t, err, "failed to run testStepWf1") + _, err = handle1.GetResult() + require.NoError(t, err, "failed to get result from testStepWf1") + + // Run second workflow, whose step uses the default (function-derived) step name + handle2, err := RunAsWorkflow(dbosCtx, testStepWf2, "test-input-2") + require.NoError(t, err, "failed to run testStepWf2") + _, err = handle2.GetResult() + require.NoError(t, err, "failed to get result from testStepWf2") + + // Get workflow steps for first workflow and check step name + steps1, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle1.GetWorkflowID()) + require.NoError(t, err, "failed to get workflow steps for testStepWf1") + require.Len(t, steps1, 1, "expected 1 step in testStepWf1") + s1 := steps1[0] + expectedStepName1 := runtime.FuncForPC(reflect.ValueOf(step1).Pointer()).Name() + assert.Equal(t, expectedStepName1, s1.StepName, "expected step name to match runtime function name") + + // Get workflow steps for second workflow and check step name + steps2, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle2.GetWorkflowID()) + require.NoError(t, err, "failed to get workflow steps for testStepWf2") + require.Len(t, steps2, 1, "expected 1 step in testStepWf2") + s2 := steps2[0] + expectedStepName2 := runtime.FuncForPC(reflect.ValueOf(step2).Pointer()).Name() + assert.Equal(t, expectedStepName2, s2.StepName, "expected step name to match runtime function name") + }) + + t.Run("customStepNames", func(t *testing.T) { + // Create a workflow that uses custom step names + customNameWorkflow := func(dbosCtx DBOSContext, input string) (string, error) { + // Run a step with a custom name + result1, err := RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return "custom-step-1-result", nil + }, WithStepName("MyCustomStep1")) + if err 
!= nil { + return "", err + } + + // Run another step with a different custom name + result2, err := RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return "custom-step-2-result", nil + }, WithStepName("MyCustomStep2")) + if err != nil { + return "", err + } + + return result1 + "-" + result2, nil + } + + RegisterWorkflow(dbosCtx, customNameWorkflow) + + // Execute the workflow + handle, err := RunAsWorkflow(dbosCtx, customNameWorkflow, "test-input") + require.NoError(t, err, "failed to run workflow with custom step names") + + result, err := handle.GetResult() + require.NoError(t, err, "failed to get result from workflow with custom step names") + assert.Equal(t, "custom-step-1-result-custom-step-2-result", result) + + // Verify the custom step names were recorded + steps, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err, "failed to get workflow steps") + require.Len(t, steps, 2, "expected 2 steps") + + // Check that the first step has the custom name + assert.Equal(t, "MyCustomStep1", steps[0].StepName, "expected first step to have custom name") + assert.Equal(t, 0, steps[0].StepID) + + // Check that the second step has the custom name + assert.Equal(t, "MyCustomStep2", steps[1].StepName, "expected second step to have custom name") + assert.Equal(t, 1, steps[1].StepID) + }) } func TestChildWorkflow(t *testing.T) {