Skip to content

Commit 7935dfb

Browse files
committed
accept a closure in RunAsStep
1 parent af4073d commit 7935dfb

File tree

5 files changed

+84
-58
lines changed

5 files changed

+84
-58
lines changed

dbos/dbos.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ type DBOSContext interface {
6161
Cancel()
6262

6363
// Workflow operations
64-
RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error)
64+
RunAsStep(_ DBOSContext, fn StepFunc) (any, error)
6565
RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error)
6666
Send(_ DBOSContext, input WorkflowSendInputInternal) error
6767
Recv(_ DBOSContext, input WorkflowRecvInput) (any, error)

dbos/queues_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,16 @@ This suite tests
3030
*/
3131

3232
func queueWorkflow(ctx DBOSContext, input string) (string, error) {
33-
step1, err := RunAsStep(ctx, queueStep, input)
33+
step1, err := RunAsStep[string](ctx, func(context context.Context) (string, error) {
34+
return queueStep(context, input)
35+
})
3436
if err != nil {
3537
return "", fmt.Errorf("failed to run step: %v", err)
3638
}
3739
return step1, nil
3840
}
3941

40-
func queueStep(ctx context.Context, input string) (string, error) {
42+
func queueStep(_ context.Context, input string) (string, error) {
4143
return input, nil
4244
}
4345

dbos/serialization_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) {
2323
}
2424

2525
func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) {
26-
stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123)
26+
stepResult, err := RunAsStep[int](ctx, func(context context.Context) (int, error) {
27+
return encodingStepBuiltinTypes(context, 123)
28+
})
2729
return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err)
2830
}
2931

@@ -49,13 +51,15 @@ type SimpleStruct struct {
4951
}
5052

5153
func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) {
52-
return RunAsStep(ctx, encodingStepStruct, StepInputStruct{
53-
A: input.A,
54-
B: fmt.Sprintf("%d", input.B),
54+
return RunAsStep[StepOutputStruct](ctx, func(context context.Context) (StepOutputStruct, error) {
55+
return encodingStepStruct(context, StepInputStruct{
56+
A: input.A,
57+
B: fmt.Sprintf("%d", input.B),
58+
})
5559
})
5660
}
5761

58-
func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputStruct, error) {
62+
func encodingStepStruct(_ context.Context, input StepInputStruct) (StepOutputStruct, error) {
5963
return StepOutputStruct{
6064
A: input,
6165
B: "processed by encodingStepStruct",

dbos/workflow.go

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -684,8 +684,8 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o
684684
/******* STEP FUNCTIONS *******/
685685
/******************************/
686686

687-
type StepFunc func(ctx context.Context, input any) (any, error)
688-
type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error)
687+
type StepFunc func(ctx context.Context) (any, error)
688+
type GenericStepFunc[R any] func(ctx context.Context) (R, error)
689689

690690
const StepParamsKey DBOSContextKey = "stepParams"
691691

@@ -729,7 +729,7 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams {
729729

730730
var typeErasedStepNameToStepName = make(map[string]string)
731731

732-
func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) {
732+
func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) {
733733
if ctx == nil {
734734
return *new(R), newStepExecutionError("", "", "ctx cannot be nil")
735735
}
@@ -738,46 +738,36 @@ func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P)
738738
return *new(R), newStepExecutionError("", "", "step function cannot be nil")
739739
}
740740

741-
// Type-erase the function based on its actual type
742-
typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) {
743-
typedInput, ok := input.(P)
744-
if !ok {
745-
return nil, newStepExecutionError("", "", fmt.Sprintf("unexpected input type: expected %T, got %T", *new(P), input))
746-
}
747-
return fn(ctx, typedInput)
748-
})
741+
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
749742

750-
typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
743+
// Type-erase the function
744+
typeErasedFn := StepFunc(func(ctx context.Context) (any, error) { return fn(ctx) })
745+
typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = stepName
751746

752747
// Call the executor method
753-
result, err := ctx.RunAsStep(ctx, typeErasedFn, input)
754-
if err != nil {
755-
// In case the errors comes from the DBOS step logic, the result will be nil and we must handle it
756-
if result == nil {
757-
return *new(R), err
758-
}
759-
return result.(R), err
748+
result, err := ctx.RunAsStep(ctx, typeErasedFn)
749+
// Step function could return a nil result
750+
if result == nil {
751+
return *new(R), err
760752
}
761-
762-
// Type-check and cast the result
753+
// Otherwise type-check and cast the result
763754
typedResult, ok := result.(R)
764755
if !ok {
765756
return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result)
766757
}
767-
768758
return typedResult, nil
769759
}
770760

771-
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) {
761+
func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) {
772762
// Get workflow state from context
773763
wfState, ok := c.Value(workflowStateKey).(*workflowState)
774764
if !ok || wfState == nil {
775765
// TODO: try to print step name
776766
return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?")
777767
}
778768

769+
// This should not happen when called from the package-level RunAsStep
779770
if fn == nil {
780-
// TODO: try to print step name
781771
return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil")
782772
}
783773

@@ -790,7 +780,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
790780

791781
// If within a step, just run the function directly
792782
if wfState.isWithinStep {
793-
return fn(c, input)
783+
return fn(c)
794784
}
795785

796786
// Setup step state
@@ -819,7 +809,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
819809
// Spawn a child DBOSContext with the step state
820810
stepCtx := WithValue(c, workflowStateKey, &stepState)
821811

822-
stepOutput, stepError := fn(stepCtx, input)
812+
stepOutput, stepError := fn(stepCtx)
823813

824814
// Retry if MaxRetries > 0 and the first execution failed
825815
var joinedErrors error
@@ -845,7 +835,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err
845835
}
846836

847837
// Execute the retry
848-
stepOutput, stepError = fn(stepCtx, input)
838+
stepOutput, stepError = fn(stepCtx)
849839

850840
// If successful, break
851841
if stepError == nil {

dbos/workflows_test.go

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,23 @@ func simpleWorkflowError(dbosCtx DBOSContext, input string) (int, error) {
3535
}
3636

3737
func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) {
38-
return RunAsStep(dbosCtx, simpleStep, input)
38+
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
39+
return simpleStep(ctx)
40+
})
3941
}
4042

41-
func simpleStep(ctx context.Context, input string) (string, error) {
43+
func simpleStep(_ context.Context) (string, error) {
4244
return "from step", nil
4345
}
4446

45-
func simpleStepError(ctx context.Context, input string) (string, error) {
47+
func simpleStepError(_ context.Context) (string, error) {
4648
return "", fmt.Errorf("step failure")
4749
}
4850

4951
func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) {
50-
return RunAsStep(dbosCtx, simpleStepError, input)
52+
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
53+
return simpleStepError(ctx)
54+
})
5155
}
5256

5357
// idempotencyWorkflow increments a global counter and returns the input
@@ -292,38 +296,44 @@ func TestWorkflowsRegistration(t *testing.T) {
292296
}
293297
}
294298

295-
func stepWithinAStep(ctx context.Context, input string) (string, error) {
296-
return simpleStep(ctx, input)
299+
func stepWithinAStep(ctx context.Context) (string, error) {
300+
return simpleStep(ctx)
297301
}
298302

299303
func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) {
300-
return RunAsStep(dbosCtx, stepWithinAStep, input)
304+
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
305+
return stepWithinAStep(ctx)
306+
})
301307
}
302308

303309
// Global counter for retry testing
304310
var stepRetryAttemptCount int
305311

306-
func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) {
312+
func stepRetryAlwaysFailsStep(ctx context.Context) (string, error) {
307313
stepRetryAttemptCount++
308314
return "", fmt.Errorf("always fails - attempt %d", stepRetryAttemptCount)
309315
}
310316

311317
var stepIdempotencyCounter int
312318

313-
func stepIdempotencyTest(ctx context.Context, input int) (string, error) {
319+
func stepIdempotencyTest(ctx context.Context) (string, error) {
314320
stepIdempotencyCounter++
315321
return "", nil
316322
}
317323

318324
func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) {
319-
RunAsStep(dbosCtx, stepIdempotencyTest, 1)
325+
RunAsStep[int](dbosCtx, func(ctx context.Context) (string, error) {
326+
return stepIdempotencyTest(ctx)
327+
})
320328
stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{
321329
MaxRetries: 5,
322330
BaseInterval: 1 * time.Millisecond,
323331
MaxInterval: 10 * time.Millisecond,
324332
})
325333

326-
return RunAsStep(stepCtx, stepRetryAlwaysFailsStep, input)
334+
return RunAsStep[string](stepCtx, func(ctx context.Context) (string, error) {
335+
return stepRetryAlwaysFailsStep(ctx)
336+
})
327337
}
328338

329339
func TestSteps(t *testing.T) {
@@ -335,7 +345,9 @@ func TestSteps(t *testing.T) {
335345

336346
t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) {
337347
// Attempt to run a step outside of a workflow context
338-
_, err := RunAsStep(dbosCtx, simpleStep, "test")
348+
_, err := RunAsStep[int](dbosCtx, func(ctx context.Context) (string, error) {
349+
return simpleStep(ctx)
350+
})
339351
if err == nil {
340352
t.Fatal("expected error when running step outside of workflow context, but got none")
341353
}
@@ -470,7 +482,9 @@ func TestChildWorkflow(t *testing.T) {
470482
return "", fmt.Errorf("expected childWf workflow ID to be %s, got %s", expectedCurrentID, workflowID)
471483
}
472484
// Steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow
473-
return RunAsStep(dbosCtx, simpleStep, "")
485+
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
486+
return simpleStep(ctx)
487+
})
474488
}
475489
RegisterWorkflow(dbosCtx, childWf)
476490

@@ -644,7 +658,9 @@ func TestChildWorkflow(t *testing.T) {
644658
customChildID := uuid.NewString()
645659

646660
simpleChildWf := func(dbosCtx DBOSContext, input string) (string, error) {
647-
return RunAsStep(dbosCtx, simpleStep, input)
661+
return RunAsStep[string](dbosCtx, func(ctx context.Context) (string, error) {
662+
return simpleStep(ctx)
663+
})
648664
}
649665
RegisterWorkflow(dbosCtx, simpleChildWf)
650666

@@ -713,23 +729,29 @@ func TestChildWorkflow(t *testing.T) {
713729
// Idempotency workflows moved to test functions
714730

715731
func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) {
716-
RunAsStep(dbosCtx, incrementCounter, int64(1))
732+
RunAsStep[int64](dbosCtx, func(ctx context.Context) (int64, error) {
733+
return incrementCounter(ctx, int64(1))
734+
})
717735
return input, nil
718736
}
719737

720738
var blockingStepStopEvent *Event
721739

722-
func blockingStep(ctx context.Context, input string) (string, error) {
740+
func blockingStep(_ context.Context) (string, error) {
723741
blockingStepStopEvent.Wait()
724742
return "", nil
725743
}
726744

727745
var idempotencyWorkflowWithStepEvent *Event
728746

729747
func idempotencyWorkflowWithStep(dbosCtx DBOSContext, input string) (int64, error) {
730-
RunAsStep(dbosCtx, incrementCounter, int64(1))
748+
RunAsStep[int64](dbosCtx, func(ctx context.Context) (int64, error) {
749+
return incrementCounter(ctx, int64(1))
750+
})
731751
idempotencyWorkflowWithStepEvent.Set()
732-
RunAsStep(dbosCtx, blockingStep, input)
752+
RunAsStep[int](dbosCtx, func(ctx context.Context) (string, error) {
753+
return blockingStep(ctx)
754+
})
733755
return idempotencyCounter, nil
734756
}
735757

@@ -1253,7 +1275,9 @@ func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, er
12531275
}
12541276

12551277
func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) {
1256-
return RunAsStep(ctx, stepThatCallsSend, input)
1278+
return RunAsStep[sendWorkflowInput](ctx, func(context context.Context) (string, error) {
1279+
return stepThatCallsSend(context, input)
1280+
})
12571281
}
12581282

12591283
type sendRecvType struct {
@@ -2193,7 +2217,7 @@ func TestWorkflowTimeout(t *testing.T) {
21932217
}
21942218
})
21952219

2196-
waitForCancelStep := func(ctx context.Context, _ string) (string, error) {
2220+
waitForCancelStep := func(ctx context.Context) (string, error) {
21972221
// This step will trigger cancellation of the entire workflow context
21982222
<-ctx.Done()
21992223
if !errors.Is(ctx.Err(), context.Canceled) && !errors.Is(ctx.Err(), context.DeadlineExceeded) {
@@ -2203,7 +2227,9 @@ func TestWorkflowTimeout(t *testing.T) {
22032227
}
22042228

22052229
waitForCancelWorkflowWithStep := func(ctx DBOSContext, _ string) (string, error) {
2206-
return RunAsStep(ctx, waitForCancelStep, "trigger-cancellation")
2230+
return RunAsStep[sendWorkflowInput](ctx, func(context context.Context) (string, error) {
2231+
return waitForCancelStep(context)
2232+
})
22072233
}
22082234
RegisterWorkflow(dbosCtx, waitForCancelWorkflowWithStep)
22092235

@@ -2240,7 +2266,9 @@ func TestWorkflowTimeout(t *testing.T) {
22402266
// The timeout will trigger a step error, the workflow can do whatever it wants with that error
22412267
stepCtx, stepCancelFunc := WithTimeout(ctx, 1*time.Millisecond)
22422268
defer stepCancelFunc() // Ensure we clean up the context
2243-
_, err := RunAsStep(stepCtx, waitForCancelStep, "short-step-timeout")
2269+
_, err := RunAsStep[string](stepCtx, func(context context.Context) (string, error) {
2270+
return waitForCancelStep(context)
2271+
})
22442272
if !errors.Is(err, context.DeadlineExceeded) {
22452273
t.Fatalf("expected step to timeout, got: %v", err)
22462274
}
@@ -2287,7 +2315,9 @@ func TestWorkflowTimeout(t *testing.T) {
22872315
// This workflow will run a step that is not cancelable.
22882316
// What this means is the workflow *will* be cancelled, but the step will run normally
22892317
stepCtx := WithoutCancel(ctx)
2290-
res, err := RunAsStep(stepCtx, detachedStep, timeout*2)
2318+
res, err := RunAsStep[time.Duration](stepCtx, func(context context.Context) (string, error) {
2319+
return detachedStep(context, timeout*2)
2320+
})
22912321
if err != nil {
22922322
t.Fatalf("failed to run detached step: %v", err)
22932323
}

0 commit comments

Comments
 (0)