Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions dbos/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dbos

import (
"context"
"encoding/gob"
"errors"
"fmt"
"log/slog"
Expand Down Expand Up @@ -241,10 +240,16 @@ func Enqueue[P any, R any](c Client, queueName, workflowName string, input P, op
}

// Register the input and outputs for gob encoding
var logger *slog.Logger
if cl, ok := c.(*client); ok {
if ctx, ok := cl.dbosCtx.(*dbosContext); ok {
logger = ctx.logger
}
}
var typedInput P
gob.Register(typedInput)
safeGobRegister(typedInput, logger)
var typedOutput R
gob.Register(typedOutput)
safeGobRegister(typedOutput, logger)

// Call the interface method with the same signature
handle, err := c.Enqueue(queueName, workflowName, input, opts...)
Expand Down
7 changes: 3 additions & 4 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dbos
import (
"context"
"crypto/sha256"
"encoding/gob"
"encoding/hex"
"errors"
"fmt"
Expand Down Expand Up @@ -317,11 +316,11 @@ func NewDBOSContext(ctx context.Context, inputConfig Config) (DBOSContext, error

// Register types we serialize with gob
var t time.Time
gob.Register(t)
safeGobRegister(t, initExecutor.logger)
var ws []WorkflowStatus
gob.Register(ws)
safeGobRegister(ws, initExecutor.logger)
var si []StepInfo
gob.Register(si)
safeGobRegister(si, initExecutor.logger)

// Initialize global variables from processed config (already handles env vars and defaults)
initExecutor.applicationVersion = config.ApplicationVersion
Expand Down
27 changes: 27 additions & 0 deletions dbos/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/base64"
"encoding/gob"
"fmt"
"log/slog"
"strings"
)

func serialize(data any) (string, error) {
Expand Down Expand Up @@ -39,3 +41,28 @@ func deserialize(data *string) (any, error) {

return result, nil
}

// safeGobRegister attempts to register a type with gob, recovering only from
// panics caused by duplicate type/name registrations (e.g., registering both T and *T).
// These specific conflicts don't affect encoding/decoding correctness, so they're safe to ignore.
// Other panics (like register `any`) are real errors and will propagate.
func safeGobRegister(value any, logger *slog.Logger) {
defer func() {
if r := recover(); r != nil {
if errStr, ok := r.(string); ok {
// Check if this is one of the two specific duplicate registration errors we want to ignore
// See https://cs.opensource.google/go/go/+/refs/tags/go1.25.1:src/encoding/gob/type.go;l=832
if strings.Contains(errStr, "gob: registering duplicate types for") ||
strings.Contains(errStr, "gob: registering duplicate names for") {
if logger != nil {
logger.Debug("gob registration conflict", "type", fmt.Sprintf("%T", value), "error", r)
}
return
}
}
// Re-panic for any other errors
panic(r)
}
}()
gob.Register(value)
}
46 changes: 37 additions & 9 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package dbos

import (
"context"
"encoding/gob"
"errors"
"fmt"
"log/slog"
"math"
"reflect"
"runtime"
Expand Down Expand Up @@ -437,10 +437,14 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], opts ...
fqn := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()

// Registry the input/output types for gob encoding
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var p P
var r R
gob.Register(p)
gob.Register(r)
safeGobRegister(p, logger)
safeGobRegister(r, logger)

// Register a type-erased version of the durable workflow for recovery
typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) {
Expand Down Expand Up @@ -1041,8 +1045,12 @@ func RunAsStep[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (R, error
}

// Register the output type for gob encoding
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var r R
gob.Register(r)
safeGobRegister(r, logger)

// Append WithStepName option to ensure the step name is set. This will not erase a user-provided step name
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
Expand Down Expand Up @@ -1205,8 +1213,12 @@ func Send[P any](ctx DBOSContext, destinationID string, message P, topic string)
if ctx == nil {
return errors.New("ctx cannot be nil")
}
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var typedMessage P
gob.Register(typedMessage)
safeGobRegister(typedMessage, logger)
return ctx.Send(ctx, destinationID, message, topic)
}

Expand Down Expand Up @@ -1282,8 +1294,12 @@ func SetEvent[P any](ctx DBOSContext, key string, message P) error {
if ctx == nil {
return errors.New("ctx cannot be nil")
}
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var typedMessage P
gob.Register(typedMessage)
safeGobRegister(typedMessage, logger)
return ctx.SetEvent(ctx, key, message)
}

Expand Down Expand Up @@ -1476,8 +1492,12 @@ func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (WorkflowHandle
}

// Register the output for gob encoding
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var r R
gob.Register(r)
safeGobRegister(r, logger)

// Call the interface method
handle, err := ctx.RetrieveWorkflow(ctx, workflowID)
Expand Down Expand Up @@ -1570,8 +1590,12 @@ func ResumeWorkflow[R any](ctx DBOSContext, workflowID string) (WorkflowHandle[R
}

// Register the output for gob encoding
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var r R
gob.Register(r)
safeGobRegister(r, logger)

_, err := ctx.ResumeWorkflow(ctx, workflowID)
if err != nil {
Expand Down Expand Up @@ -1662,8 +1686,12 @@ func ForkWorkflow[R any](ctx DBOSContext, input ForkWorkflowInput) (WorkflowHand
}

// Register the output for gob encoding
var logger *slog.Logger
if c, ok := ctx.(*dbosContext); ok {
logger = c.logger
}
var r R
gob.Register(r)
safeGobRegister(r, logger)

handle, err := ctx.ForkWorkflow(ctx, input)
if err != nil {
Expand Down
172 changes: 172 additions & 0 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -353,6 +354,177 @@ func TestWorkflowsRegistration(t *testing.T) {
}()
RegisterWorkflow(freshCtx, simpleWorkflow)
})

t.Run("SafeGobRegister", func(t *testing.T) {
// Create a fresh DBOS context for this test
freshCtx := setupDBOS(t, false, true) // Don't reset DB but do check for leaks

// Test 1: Basic type vs pointer conflicts
type TestType struct {
Value string
}

// Register workflows that use the same type to trigger potential gob conflicts
// The safeGobRegister calls within RegisterWorkflow should handle the conflicts
workflow1 := func(ctx DBOSContext, input TestType) (TestType, error) {
return input, nil
}
workflow2 := func(ctx DBOSContext, input *TestType) (*TestType, error) {
return input, nil
}

// Both registrations should succeed despite using conflicting types (T and *T)
RegisterWorkflow(freshCtx, workflow1)
RegisterWorkflow(freshCtx, workflow2)

// Test 2: Multiple workflows with the same types (duplicate registrations)
workflow3 := func(ctx DBOSContext, input TestType) (TestType, error) {
return TestType{Value: input.Value + "-modified"}, nil
}
workflow4 := func(ctx DBOSContext, input TestType) (TestType, error) {
return TestType{Value: input.Value + "-another"}, nil
}

// These should succeed even though TestType is already registered
RegisterWorkflow(freshCtx, workflow3)
RegisterWorkflow(freshCtx, workflow4)

// Test 3: Nested structs
type InnerType struct {
ID int
}
type OuterType struct {
Inner InnerType
Name string
}

workflow5 := func(ctx DBOSContext, input OuterType) (OuterType, error) {
return input, nil
}
workflow6 := func(ctx DBOSContext, input *OuterType) (*OuterType, error) {
return input, nil
}

RegisterWorkflow(freshCtx, workflow5)
RegisterWorkflow(freshCtx, workflow6)

// Test 4: Slice and map types
workflow7 := func(ctx DBOSContext, input []TestType) ([]TestType, error) {
return input, nil
}
workflow8 := func(ctx DBOSContext, input []*TestType) ([]*TestType, error) {
return input, nil
}
workflow9 := func(ctx DBOSContext, input map[string]TestType) (map[string]TestType, error) {
return input, nil
}
workflow10 := func(ctx DBOSContext, input map[string]*TestType) (map[string]*TestType, error) {
return input, nil
}

RegisterWorkflow(freshCtx, workflow7)
RegisterWorkflow(freshCtx, workflow8)
RegisterWorkflow(freshCtx, workflow9)
RegisterWorkflow(freshCtx, workflow10)

// Launch and verify the system still works
err := Launch(freshCtx)
require.NoError(t, err, "failed to launch DBOS after gob conflict handling")
defer Shutdown(freshCtx, 10*time.Second)

// Test all registered workflows to ensure they work correctly

// Run workflow1 with value type
testValue := TestType{Value: "test"}
handle1, err := RunWorkflow(freshCtx, workflow1, testValue)
require.NoError(t, err, "failed to run workflow1")
result1, err := handle1.GetResult()
require.NoError(t, err, "failed to get result from workflow1")
assert.Equal(t, testValue, result1, "unexpected result from workflow1")

// Run workflow2 with pointer type
testPointer := &TestType{Value: "pointer"}
handle2, err := RunWorkflow(freshCtx, workflow2, testPointer)
require.NoError(t, err, "failed to run workflow2")
result2, err := handle2.GetResult()
require.NoError(t, err, "failed to get result from workflow2")
assert.Equal(t, testPointer, result2, "unexpected result from workflow2")

// Run workflow3 with modified output
handle3, err := RunWorkflow(freshCtx, workflow3, testValue)
require.NoError(t, err, "failed to run workflow3")
result3, err := handle3.GetResult()
require.NoError(t, err, "failed to get result from workflow3")
assert.Equal(t, TestType{Value: "test-modified"}, result3, "unexpected result from workflow3")

// Run workflow5 with nested struct
testOuter := OuterType{Inner: InnerType{ID: 42}, Name: "test"}
handle5, err := RunWorkflow(freshCtx, workflow5, testOuter)
require.NoError(t, err, "failed to run workflow5")
result5, err := handle5.GetResult()
require.NoError(t, err, "failed to get result from workflow5")
assert.Equal(t, testOuter, result5, "unexpected result from workflow5")

// Run workflow6 with nested struct pointer
testOuterPtr := &OuterType{Inner: InnerType{ID: 43}, Name: "test-ptr"}
handle6, err := RunWorkflow(freshCtx, workflow6, testOuterPtr)
require.NoError(t, err, "failed to run workflow6")
result6, err := handle6.GetResult()
require.NoError(t, err, "failed to get result from workflow6")
assert.Equal(t, testOuterPtr, result6, "unexpected result from workflow6")

// Run workflow7 with slice type
testSlice := []TestType{{Value: "a"}, {Value: "b"}}
handle7, err := RunWorkflow(freshCtx, workflow7, testSlice)
require.NoError(t, err, "failed to run workflow7")
result7, err := handle7.GetResult()
require.NoError(t, err, "failed to get result from workflow7")
assert.Equal(t, testSlice, result7, "unexpected result from workflow7")

// Run workflow8 with pointer slice type
testPtrSlice := []*TestType{{Value: "a"}, {Value: "b"}}
handle8, err := RunWorkflow(freshCtx, workflow8, testPtrSlice)
require.NoError(t, err, "failed to run workflow8")
result8, err := handle8.GetResult()
require.NoError(t, err, "failed to get result from workflow8")
assert.Equal(t, testPtrSlice, result8, "unexpected result from workflow8")

// Run workflow9 with map type
testMap := map[string]TestType{"key1": {Value: "value1"}}
handle9, err := RunWorkflow(freshCtx, workflow9, testMap)
require.NoError(t, err, "failed to run workflow9")
result9, err := handle9.GetResult()
require.NoError(t, err, "failed to get result from workflow9")
assert.Equal(t, testMap, result9, "unexpected result from workflow9")

// Run workflow10 with pointer map type
testPtrMap := map[string]*TestType{"key1": {Value: "value1"}}
handle10, err := RunWorkflow(freshCtx, workflow10, testPtrMap)
require.NoError(t, err, "failed to run workflow10")
result10, err := handle10.GetResult()
require.NoError(t, err, "failed to get result from workflow10")
assert.Equal(t, testPtrMap, result10, "unexpected result from workflow10")

t.Run("validPanic", func(t *testing.T) {
// Verify that non-duplicate registration panics are still propagated
workflow11 := func(ctx DBOSContext, input any) (any, error) {
return input, nil
}

// This should panic during registration because interface{} creates a nil value
// which gob.Register cannot handle
defer func() {
r := recover()
require.NotNil(t, r, "expected panic from interface{} registration but got none")
// Verify it's not a duplicate registration error (which would be caught)
if errStr, ok := r.(string); ok {
assert.False(t, strings.Contains(errStr, "gob: registering duplicate"),
"panic should not be a duplicate registration error, got: %v", r)
}
}()
RegisterWorkflow(freshCtx, workflow11) // This should panic
})
})
}

func stepWithinAStep(ctx context.Context) (string, error) {
Expand Down
Loading