diff --git a/dbos/dbos.go b/dbos/dbos.go index 18cccf99..1d3a70db 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -61,7 +61,7 @@ type DBOSContext interface { Cancel() // Workflow operations - RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) + RunAsStep(_ DBOSContext, fn StepFunc) (any, error) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) Send(_ DBOSContext, input WorkflowSendInputInternal) error Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 0b8733b9..eda6dd57 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -30,14 +30,16 @@ This suite tests */ func queueWorkflow(ctx DBOSContext, input string) (string, error) { - step1, err := RunAsStep(ctx, queueStep, input) + step1, err := RunAsStep(ctx, func(context context.Context) (string, error) { + return queueStep(context, input) + }) if err != nil { return "", fmt.Errorf("failed to run step: %v", err) } return step1, nil } -func queueStep(ctx context.Context, input string) (string, error) { +func queueStep(_ context.Context, input string) (string, error) { return input, nil } diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index bc8ebd4a..c367c89d 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -23,7 +23,9 @@ func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { } func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) { - stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123) + stepResult, err := RunAsStep(ctx, func(context context.Context) (int, error) { + return encodingStepBuiltinTypes(context, 123) + }) return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err) } @@ -49,13 +51,15 @@ type SimpleStruct struct { } func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) { - return RunAsStep(ctx, encodingStepStruct, StepInputStruct{ - A: input.A, - B: fmt.Sprintf("%d", input.B), + return RunAsStep(ctx, func(context context.Context) (StepOutputStruct, error) { + return encodingStepStruct(context, StepInputStruct{ + A: input.A, + B: fmt.Sprintf("%d", input.B), + }) }) } -func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputStruct, error) { +func encodingStepStruct(_ context.Context, input StepInputStruct) (StepOutputStruct, error) { return StepOutputStruct{ A: input, B: "processed by encodingStepStruct", diff --git a/dbos/workflow.go b/dbos/workflow.go index e6802307..a3eac4a2 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -684,8 +684,8 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o /******* STEP FUNCTIONS *******/ /******************************/ -type StepFunc func(ctx context.Context, input any) (any, error) -type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error) +type StepFunc func(ctx context.Context) (any, error) +type GenericStepFunc[R any] func(ctx context.Context) (R, error) const StepParamsKey DBOSContextKey = "stepParams" @@ -729,7 +729,7 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams { var typeErasedStepNameToStepName = make(map[string]string) -func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) { +func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) { if ctx == nil { return *new(R), newStepExecutionError("", "", "ctx cannot be nil") } @@ -738,37 +738,27 @@ func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) return *new(R), newStepExecutionError("", "", "step function cannot be nil") } - // Type-erase the function based on its actual type - typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) { - typedInput, ok := input.(P) - if !ok { - return nil, newStepExecutionError("", "", fmt.Sprintf("unexpected input type: expected %T, got %T", *new(P), input)) - } - return fn(ctx, typedInput) - }) + stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() - typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + // Type-erase the function + typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) }) + typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = stepName - // Call the executor method - result, err := ctx.RunAsStep(ctx, typeErasedFn, input) - if err != nil { - // In case the errors comes from the DBOS step logic, the result will be nil and we must handle it - if result == nil { - return *new(R), err - } - return result.(R), err + // Call the executor method and pass through the result/error + result, err := ctx.RunAsStep(ctx, typeErasedFn) + // Step function could return a nil result + if result == nil { + return *new(R), err } - - // Type-check and cast the result + // Otherwise type-check and cast the result typedResult, ok := result.(R) if !ok { return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result) } - - return typedResult, nil + return typedResult, err } -func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) { +func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { // Get workflow state from context wfState, ok := c.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { @@ -776,8 +766,8 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?") } + // This should not happen when called from the package-level RunAsStep if fn == nil { - // TODO: try to print step name return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil") } @@ -790,7 +780,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err // If within a step, just run the function directly if wfState.isWithinStep { - return fn(c, input) + return fn(c) } // Setup step state @@ -819,7 +809,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err // Spawn a child DBOSContext with the step state stepCtx := WithValue(c, workflowStateKey, &stepState) - stepOutput, stepError := fn(stepCtx, input) + stepOutput, stepError := fn(stepCtx) // Retry if MaxRetries > 0 and the first execution failed var joinedErrors error @@ -845,7 +835,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err } // Execute the retry - stepOutput, stepError = fn(stepCtx, input) + stepOutput, stepError = fn(stepCtx) // If successful, break if stepError == nil { diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 3c5913cd..45112b5f 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -35,19 +35,23 @@ func simpleWorkflowError(dbosCtx DBOSContext, input string) (int, error) { } func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) { - return RunAsStep(dbosCtx, simpleStep, input) + return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return simpleStep(ctx) + }) } -func simpleStep(ctx context.Context, input string) (string, error) { +func simpleStep(_ context.Context) (string, error) { return "from step", nil } -func simpleStepError(ctx context.Context, input string) (string, error) { +func simpleStepError(_ context.Context) (string, error) { return "", fmt.Errorf("step failure") } func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) { - return RunAsStep(dbosCtx, simpleStepError, input) + return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return simpleStepError(ctx) + }) } // idempotencyWorkflow increments a global counter and returns the input @@ -292,38 +296,44 @@ func TestWorkflowsRegistration(t *testing.T) { } } -func stepWithinAStep(ctx context.Context, input string) (string, error) { - return simpleStep(ctx, input) +func stepWithinAStep(ctx context.Context) (string, error) { + return simpleStep(ctx) } func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) { - return RunAsStep(dbosCtx, stepWithinAStep, input) + return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return stepWithinAStep(ctx) + }) } // Global counter for retry testing var stepRetryAttemptCount int -func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) { +func stepRetryAlwaysFailsStep(ctx context.Context) (string, error) { stepRetryAttemptCount++ return "", fmt.Errorf("always fails - attempt %d", stepRetryAttemptCount) } var stepIdempotencyCounter int -func stepIdempotencyTest(ctx context.Context, input int) (string, error) { +func stepIdempotencyTest(ctx context.Context) (string, error) { stepIdempotencyCounter++ return "", nil } func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) { - RunAsStep(dbosCtx, stepIdempotencyTest, 1) + RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return stepIdempotencyTest(ctx) + }) stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{ MaxRetries: 5, BaseInterval: 1 * time.Millisecond, MaxInterval: 10 * time.Millisecond, }) - return RunAsStep(stepCtx, stepRetryAlwaysFailsStep, input) + return RunAsStep(stepCtx, func(ctx context.Context) (string, error) { + return stepRetryAlwaysFailsStep(ctx) + }) } func TestSteps(t *testing.T) { @@ -335,7 +345,9 @@ func TestSteps(t *testing.T) { t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { // Attempt to run a step outside of a workflow context - _, err := RunAsStep(dbosCtx, simpleStep, "test") + _, err := RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return simpleStep(ctx) + }) if err == nil { t.Fatal("expected error when running step outside of workflow context, but got none") } @@ -411,7 +423,7 @@ func TestSteps(t *testing.T) { } // Verify the error contains the step name and max retries - expectedErrorMessage := "dbos.stepRetryAlwaysFailsStep has exceeded its maximum of 5 retries" + expectedErrorMessage := "has exceeded its maximum of 5 retries" if !strings.Contains(dbosErr.Message, expectedErrorMessage) { t.Fatalf("expected error message to contain '%s', got '%s'", expectedErrorMessage, dbosErr.Message) } @@ -470,7 +482,9 @@ func TestChildWorkflow(t *testing.T) { return "", fmt.Errorf("expected childWf workflow ID to be %s, got %s", expectedCurrentID, workflowID) } // Steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow - return RunAsStep(dbosCtx, simpleStep, "") + return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return simpleStep(ctx) + }) } RegisterWorkflow(dbosCtx, childWf) @@ -644,7 +658,9 @@ func TestChildWorkflow(t *testing.T) { customChildID := uuid.NewString() simpleChildWf := func(dbosCtx DBOSContext, input string) (string, error) { - return RunAsStep(dbosCtx, simpleStep, input) + return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return simpleStep(ctx) + }) } RegisterWorkflow(dbosCtx, simpleChildWf) @@ -713,13 +729,15 @@ func TestChildWorkflow(t *testing.T) { // Idempotency workflows moved to test functions func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) { - RunAsStep(dbosCtx, incrementCounter, int64(1)) + RunAsStep(dbosCtx, func(ctx context.Context) (int64, error) { + return incrementCounter(ctx, int64(1)) + }) return input, nil } var blockingStepStopEvent *Event -func blockingStep(ctx context.Context, input string) (string, error) { +func blockingStep(_ context.Context) (string, error) { blockingStepStopEvent.Wait() return "", nil } @@ -727,9 +745,13 @@ func blockingStep(ctx context.Context, input string) (string, error) { var idempotencyWorkflowWithStepEvent *Event func idempotencyWorkflowWithStep(dbosCtx DBOSContext, input string) (int64, error) { - RunAsStep(dbosCtx, incrementCounter, int64(1)) + RunAsStep(dbosCtx, func(ctx context.Context) (int64, error) { + return incrementCounter(ctx, int64(1)) + }) idempotencyWorkflowWithStepEvent.Set() - RunAsStep(dbosCtx, blockingStep, input) + RunAsStep(dbosCtx, func(ctx context.Context) (string, error) { + return blockingStep(ctx) + }) return idempotencyCounter, nil } @@ -1131,9 +1153,9 @@ func TestScheduledWorkflows(t *testing.T) { } // Stop the workflowScheduler and check if it stops executing - currentCounter := counter dbosCtx.(*dbosContext).getWorkflowScheduler().Stop() time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions + currentCounter := counter // If more scheduled executions happen, this can also trigger a data race. If the scheduler is correct, there should be no race. if counter >= currentCounter+2 { t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2) } @@ -1253,7 +1275,9 @@ func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, er } func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) { - return RunAsStep(ctx, stepThatCallsSend, input) + return RunAsStep(ctx, func(context context.Context) (string, error) { + return stepThatCallsSend(context, input) + }) } type sendRecvType struct { @@ -2193,7 +2217,7 @@ func TestWorkflowTimeout(t *testing.T) { } }) - waitForCancelStep := func(ctx context.Context, _ string) (string, error) { + waitForCancelStep := func(ctx context.Context) (string, error) { // This step will trigger cancellation of the entire workflow context <-ctx.Done() if !errors.Is(ctx.Err(), context.Canceled) && !errors.Is(ctx.Err(), context.DeadlineExceeded) { @@ -2203,7 +2227,9 @@ func TestWorkflowTimeout(t *testing.T) { } waitForCancelWorkflowWithStep := func(ctx DBOSContext, _ string) (string, error) { - return RunAsStep(ctx, waitForCancelStep, "trigger-cancellation") + return RunAsStep(ctx, func(context context.Context) (string, error) { + return waitForCancelStep(context) + }) } RegisterWorkflow(dbosCtx, waitForCancelWorkflowWithStep) @@ -2240,7 +2266,9 @@ func TestWorkflowTimeout(t *testing.T) { // The timeout will trigger a step error, the workflow can do whatever it wants with that error stepCtx, stepCancelFunc := WithTimeout(ctx, 1*time.Millisecond) defer stepCancelFunc() // Ensure we clean up the context - _, err := RunAsStep(stepCtx, waitForCancelStep, "short-step-timeout") + _, err := RunAsStep(stepCtx, func(context context.Context) (string, error) { + return waitForCancelStep(context) + }) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected step to timeout, got: %v", err) } @@ -2287,7 +2315,9 @@ func TestWorkflowTimeout(t *testing.T) { // This workflow will run a step that is not cancelable. // What this means is the workflow *will* be cancelled, but the step will run normally stepCtx := WithoutCancel(ctx) - res, err := RunAsStep(stepCtx, detachedStep, timeout*2) + res, err := RunAsStep(stepCtx, func(context context.Context) (string, error) { + return detachedStep(context, timeout*2) + }) if err != nil { t.Fatalf("failed to run detached step: %v", err) }