Skip to content

Commit 0ad397a

Browse files
committed
fix steps output type gob registration and add a test
1 parent 19994c9 commit 0ad397a

File tree

2 files changed

+115
-4
lines changed

2 files changed

+115
-4
lines changed

dbos/workflow.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,10 @@ func RunAsStep[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (R, error
10411041
return *new(R), newStepExecutionError("", "", "step function cannot be nil")
10421042
}
10431043

1044+
// Register the output type for gob encoding
1045+
var r R
1046+
gob.Register(r)
1047+
10441048
// Append WithStepName option to ensure the step name is set. This will not erase a user-provided step name
10451049
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
10461050
opts = append(opts, WithStepName(stepName))
@@ -1490,9 +1494,9 @@ func (c *dbosContext) CancelWorkflow(_ DBOSContext, workflowID string) error {
14901494
workflowState, ok := c.Value(workflowStateKey).(*workflowState)
14911495
isWithinWorkflow := ok && workflowState != nil
14921496
if isWithinWorkflow {
1493-
_, err := RunAsStep(c, func(ctx context.Context) (any, error) {
1497+
_, err := RunAsStep(c, func(ctx context.Context) (string, error) {
14941498
err := c.systemDB.cancelWorkflow(ctx, workflowID)
1495-
return nil, err
1499+
return "", err
14961500
}, WithStepName("DBOS.cancelWorkflow"))
14971501
return err
14981502
} else {
@@ -1527,9 +1531,9 @@ func (c *dbosContext) ResumeWorkflow(_ DBOSContext, workflowID string) (Workflow
15271531
isWithinWorkflow := ok && workflowState != nil
15281532
var err error
15291533
if isWithinWorkflow {
1530-
_, err = RunAsStep(c, func(ctx context.Context) (any, error) {
1534+
_, err = RunAsStep(c, func(ctx context.Context) (string, error) {
15311535
err := c.systemDB.resumeWorkflow(ctx, workflowID)
1532-
return nil, err
1536+
return "", err
15331537
}, WithStepName("DBOS.resumeWorkflow"))
15341538
} else {
15351539
err = c.systemDB.resumeWorkflow(c, workflowID)

dbos/workflows_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,113 @@ func TestSteps(t *testing.T) {
567567
assert.Equal(t, "MyCustomStep2", steps[1].StepName, "expected second step to have custom name")
568568
assert.Equal(t, 1, steps[1].StepID)
569569
})
570+
571+
t.Run("stepsOutputEncoding", func(t *testing.T) {
572+
// Define user-defined types for testing serialization
573+
type StepInput struct {
574+
Name string `json:"name"`
575+
Count int `json:"count"`
576+
Active bool `json:"active"`
577+
Metadata map[string]string `json:"metadata"`
578+
CreatedAt time.Time `json:"created_at"`
579+
}
580+
581+
type StepOutput struct {
582+
ProcessedName string `json:"processed_name"`
583+
TotalCount int `json:"total_count"`
584+
Success bool `json:"success"`
585+
ProcessedAt time.Time `json:"processed_at"`
586+
Details []string `json:"details"`
587+
}
588+
589+
// Create a step function that accepts StepInput and returns StepOutput
590+
processUserObjectStep := func(_ context.Context, input StepInput) (StepOutput, error) {
591+
// Process the input and create output
592+
output := StepOutput{
593+
ProcessedName: fmt.Sprintf("Processed_%s", input.Name),
594+
TotalCount: input.Count * 2,
595+
Success: input.Active,
596+
ProcessedAt: time.Now(),
597+
Details: []string{"step1", "step2", "step3"},
598+
}
599+
600+
// Verify input was correctly deserialized
601+
if input.Metadata == nil {
602+
return StepOutput{}, fmt.Errorf("metadata map was not properly deserialized")
603+
}
604+
605+
return output, nil
606+
}
607+
608+
// Create a workflow that uses the step with user-defined objects
609+
userObjectWorkflow := func(dbosCtx DBOSContext, workflowInput string) (string, error) {
610+
// Create input for the step
611+
stepInput := StepInput{
612+
Name: workflowInput,
613+
Count: 42,
614+
Active: true,
615+
Metadata: map[string]string{
616+
"key1": "value1",
617+
"key2": "value2",
618+
},
619+
CreatedAt: time.Now(),
620+
}
621+
622+
// Run the step with user-defined input and output
623+
output, err := RunAsStep(dbosCtx, func(ctx context.Context) (StepOutput, error) {
624+
return processUserObjectStep(ctx, stepInput)
625+
})
626+
if err != nil {
627+
return "", fmt.Errorf("step failed: %w", err)
628+
}
629+
630+
// Verify the output was correctly returned
631+
if output.ProcessedName == "" {
632+
return "", fmt.Errorf("output ProcessedName is empty")
633+
}
634+
if output.TotalCount != 84 {
635+
return "", fmt.Errorf("expected TotalCount to be 84, got %d", output.TotalCount)
636+
}
637+
if len(output.Details) != 3 {
638+
return "", fmt.Errorf("expected 3 details, got %d", len(output.Details))
639+
}
640+
641+
return "", nil
642+
}
643+
644+
// Register the workflow
645+
RegisterWorkflow(dbosCtx, userObjectWorkflow)
646+
647+
// Execute the workflow
648+
handle, err := RunWorkflow(dbosCtx, userObjectWorkflow, "TestObject")
649+
require.NoError(t, err, "failed to run workflow with user-defined objects")
650+
651+
// Get the result
652+
_, err = handle.GetResult()
653+
require.NoError(t, err, "failed to get result from workflow")
654+
655+
// Verify the step was recorded
656+
steps, err := GetWorkflowSteps(dbosCtx, handle.GetWorkflowID())
657+
require.NoError(t, err, "failed to get workflow steps")
658+
require.Len(t, steps, 1, "expected 1 step")
659+
660+
// Verify step output was properly serialized and stored
661+
step := steps[0]
662+
require.NotNil(t, step.Output, "step output should not be nil")
663+
assert.Nil(t, step.Error)
664+
665+
// Deserialize the output from the database to verify proper encoding
666+
storedOutput, ok := step.Output.(StepOutput)
667+
require.True(t, ok, "failed to cast step output to StepOutput")
668+
669+
// Verify all fields were correctly serialized and deserialized
670+
assert.Equal(t, "Processed_TestObject", storedOutput.ProcessedName, "ProcessedName not correctly serialized")
671+
assert.Equal(t, 84, storedOutput.TotalCount, "TotalCount not correctly serialized")
672+
assert.True(t, storedOutput.Success, "Success flag not correctly serialized")
673+
assert.Len(t, storedOutput.Details, 3, "Details array length incorrect")
674+
assert.Equal(t, []string{"step1", "step2", "step3"}, storedOutput.Details, "Details array not correctly serialized")
675+
assert.False(t, storedOutput.ProcessedAt.IsZero(), "ProcessedAt timestamp should not be zero")
676+
})
570677
}
571678

572679
func TestChildWorkflow(t *testing.T) {

0 commit comments

Comments
 (0)