Skip to content

Commit 4b400c5

Browse files
committed
differentiate nil values and zero values
1 parent 6fe311c commit 4b400c5

File tree

6 files changed

+109
-66
lines changed

6 files changed

+109
-66
lines changed

dbos/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ func Enqueue[P any, R any](c Client, queueName, workflowName string, input P, op
249249
}
250250

251251
// Call the interface method with the same signature
252-
handle, err := c.Enqueue(queueName, workflowName, &encodedInput, opts...)
252+
handle, err := c.Enqueue(queueName, workflowName, encodedInput, opts...)
253253
if err != nil {
254254
return nil, err
255255
}

dbos/serialization.go

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
)
99

1010
type serializer[T any] interface {
11-
Encode(data T) (string, error)
11+
Encode(data T) (*string, error)
1212
Decode(data *string) (T, error)
1313
}
1414

@@ -18,23 +18,38 @@ func newJSONSerializer[T any]() serializer[T] {
1818
return &jsonSerializer[T]{}
1919
}
2020

21-
func (j *jsonSerializer[T]) Encode(data T) (string, error) {
22-
if isNilOrZeroValue(data) {
23-
// For nil values, encode an empty byte slice directly to base64
24-
return base64.StdEncoding.EncodeToString([]byte{}), nil
21+
func (j *jsonSerializer[T]) Encode(data T) (*string, error) {
22+
// Check if the value is nil (for pointer types, slice, map, etc.)
23+
if isNilValue(data) {
24+
// For nil values, return nil pointer
25+
return nil, nil
26+
}
27+
28+
// Check if the value is a zero value (but not nil)
29+
if isZeroValue(data) {
30+
// For zero values, encode an empty byte slice directly to base64
31+
emptyStr := base64.StdEncoding.EncodeToString([]byte{})
32+
return &emptyStr, nil
2533
}
2634

2735
jsonBytes, err := json.Marshal(data)
2836
if err != nil {
29-
return "", fmt.Errorf("failed to encode data: %w", err)
37+
return nil, fmt.Errorf("failed to encode data: %w", err)
3038
}
31-
return base64.StdEncoding.EncodeToString(jsonBytes), nil
39+
encodedStr := base64.StdEncoding.EncodeToString(jsonBytes)
40+
return &encodedStr, nil
3241
}
3342

3443
func (j *jsonSerializer[T]) Decode(data *string) (T, error) {
3544
var result T
3645

37-
if data == nil || *data == "" {
46+
// If data is a nil pointer, return nil (for pointer types) or zero value (for non-pointer types)
47+
if data == nil {
48+
return getNilOrZeroValue[T](), nil
49+
}
50+
51+
// If *data is an empty string, return zero value
52+
if *data == "" {
3853
return result, nil
3954
}
4055

@@ -43,7 +58,7 @@ func (j *jsonSerializer[T]) Decode(data *string) (T, error) {
4358
return result, fmt.Errorf("failed to decode base64 data: %w", err)
4459
}
4560

46-
// If decoded data is empty, it represents a nil value
61+
// If decoded data is empty, it represents a zero value
4762
if len(dataBytes) == 0 {
4863
return result, nil
4964
}
@@ -57,39 +72,45 @@ func (j *jsonSerializer[T]) Decode(data *string) (T, error) {
5772
return result, nil
5873
}
5974

60-
// isNilOrZeroValue checks if a value is nil (for pointer types, slice, map, etc.) or a zero value.
61-
func isNilOrZeroValue(v any) bool {
75+
// isNilValue checks if a value is nil (for pointer types, slice, map, etc.).
76+
func isNilValue(v any) bool {
6277
val := reflect.ValueOf(v)
6378
if !val.IsValid() {
6479
return true
6580
}
6681
switch val.Kind() {
67-
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
82+
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface:
6883
return val.IsNil()
6984
}
70-
// For other types, check if it's the zero value
71-
return val.IsZero()
85+
return false
7286
}
7387

74-
// IsNestedPointer checks if a type is a nested pointer (e.g., **int, ***int).
75-
// It returns false for non-pointer types and single-level pointers (*int).
76-
// It returns true for nested pointers with depth > 1.
77-
func IsNestedPointer(t reflect.Type) bool {
78-
if t == nil {
88+
// isZeroValue checks if a value is a zero value (but not nil).
89+
func isZeroValue(v any) bool {
90+
val := reflect.ValueOf(v)
91+
if !val.IsValid() {
7992
return false
8093
}
81-
82-
depth := 0
83-
currentType := t
84-
85-
// Count pointer indirection levels, break early if depth > 1
86-
for currentType != nil && currentType.Kind() == reflect.Pointer {
87-
depth++
88-
if depth > 1 {
89-
return true
90-
}
91-
currentType = currentType.Elem()
94+
// For pointer-like types, if it's not nil, it's not a zero value
95+
switch val.Kind() {
96+
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface:
97+
return false
9298
}
99+
// For other types, check if it's the zero value
100+
return val.IsZero()
101+
}
93102

94-
return false
103+
// getNilOrZeroValue returns nil for pointer types, or zero value for non-pointer types.
104+
func getNilOrZeroValue[T any]() T {
105+
var result T
106+
resultType := reflect.TypeOf(result)
107+
if resultType == nil {
108+
return result
109+
}
110+
// If T is a pointer type, return nil
111+
if resultType.Kind() == reflect.Pointer {
112+
return reflect.Zero(resultType).Interface().(T)
113+
}
114+
// Otherwise return zero value
115+
return result
95116
}

dbos/serialization_test.go

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"reflect"
78
"testing"
89
"time"
910

@@ -26,7 +27,17 @@ func testAllSerializationPaths[T any](
2627
) {
2728
t.Helper()
2829

29-
isNilExpected := isNilOrZeroValue(input)
30+
// Check if input is nil (for pointer types, slice, map, etc.)
31+
val := reflect.ValueOf(input)
32+
isNilExpected := false
33+
if !val.IsValid() {
34+
isNilExpected = true
35+
} else {
36+
switch val.Kind() {
37+
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
38+
isNilExpected = val.IsNil()
39+
}
40+
}
3041

3142
// Setup events for recovery
3243
startEvent := NewEvent()
@@ -98,17 +109,23 @@ func testAllSerializationPaths[T any](
98109
lastStep := steps[len(steps)-1]
99110
if isNilExpected {
100111
// Should be an empty string
101-
assert.Equal(t, "", lastStep.Output, "Step output should be an empty string")
112+
assert.Nil(t, lastStep.Output, "Step output should be nil")
102113
} else {
103114
require.NotNil(t, lastStep.Output)
104115
// GetWorkflowSteps returns a string (base64-decoded JSON)
105116
// Unmarshal the JSON string into type T
106117
strValue, ok := lastStep.Output.(string)
107118
require.True(t, ok, "Step output should be a string")
108-
var decodedOutput T
109-
err := json.Unmarshal([]byte(strValue), &decodedOutput)
110-
require.NoError(t, err, "Failed to unmarshal step output to type T")
111-
assert.Equal(t, expectedOutput, decodedOutput, "Step output should match expected output")
119+
// We encode zero values as empty strings. End users are expected to handle this.
120+
if strValue == "" {
121+
var zero T
122+
assert.Equal(t, zero, expectedOutput, "Step output should be the zero value of type T")
123+
} else {
124+
var decodedOutput T
125+
err := json.Unmarshal([]byte(strValue), &decodedOutput)
126+
require.NoError(t, err, "Failed to unmarshal step output to type T")
127+
assert.Equal(t, expectedOutput, decodedOutput, "Step output should match expected output")
128+
}
112129
}
113130
assert.Nil(t, lastStep.Error)
114131
}
@@ -124,8 +141,8 @@ func testAllSerializationPaths[T any](
124141
wf := wfs[0]
125142
if isNilExpected {
126143
// Should be an empty string
127-
assert.Equal(t, "", wf.Input, "Workflow input should be an empty string")
128-
assert.Equal(t, "", wf.Output, "Workflow output should be an empty string")
144+
assert.Nil(t, wf.Input, "Workflow input should be nil")
145+
assert.Nil(t, wf.Output, "Workflow output should be nil")
129146
} else {
130147
require.NotNil(t, wf.Input)
131148
require.NotNil(t, wf.Output)
@@ -137,16 +154,25 @@ func testAllSerializationPaths[T any](
137154
outputStr, ok := wf.Output.(string)
138155
require.True(t, ok, "Workflow output should be a string")
139156

140-
var decodedInput T
141-
err := json.Unmarshal([]byte(inputStr), &decodedInput)
142-
require.NoError(t, err, "Failed to unmarshal workflow input to type T")
143-
144-
var decodedOutput T
145-
err = json.Unmarshal([]byte(outputStr), &decodedOutput)
146-
require.NoError(t, err, "Failed to unmarshal workflow output to type T")
157+
if inputStr == "" {
158+
var zero T
159+
assert.Equal(t, zero, input, "Workflow input should be the zero value of type T")
160+
} else {
161+
var decodedInput T
162+
err := json.Unmarshal([]byte(inputStr), &decodedInput)
163+
require.NoError(t, err, "Failed to unmarshal workflow input to type T")
164+
assert.Equal(t, input, decodedInput, "Workflow input should match input")
165+
}
147166

148-
assert.Equal(t, input, decodedInput, "Workflow input should match input")
149-
assert.Equal(t, expectedOutput, decodedOutput, "Workflow output should match expected output")
167+
if outputStr == "" {
168+
var zero T
169+
assert.Equal(t, zero, expectedOutput, "Workflow output should be the zero value of type T")
170+
} else {
171+
var decodedOutput T
172+
err = json.Unmarshal([]byte(outputStr), &decodedOutput)
173+
require.NoError(t, err, "Failed to unmarshal workflow output to type T")
174+
assert.Equal(t, expectedOutput, decodedOutput, "Workflow output should match expected output")
175+
}
150176
}
151177
})
152178
}
@@ -781,13 +807,6 @@ func TestSerializer(t *testing.T) {
781807
testSendRecv(t, executor, serializerIntPtrSenderWorkflow, serializerIntPtrReceiverWorkflow, input, "typed-intptr-set-sender-wf")
782808
testSetGetEvent(t, executor, serializerIntPtrSetEventWorkflow, serializerIntPtrGetEventWorkflow, input, "typed-intptr-set-setevent-wf", "typed-intptr-set-getevent-wf")
783809
})
784-
785-
// Test *int (pointer type, nil)
786-
t.Run("IntPtrNil", func(t *testing.T) {
787-
var input *int = nil
788-
testSendRecv(t, executor, serializerIntPtrSenderWorkflow, serializerIntPtrReceiverWorkflow, input, "typed-intptr-nil-sender-wf")
789-
testSetGetEvent(t, executor, serializerIntPtrSetEventWorkflow, serializerIntPtrGetEventWorkflow, input, "typed-intptr-nil-setevent-wf", "typed-intptr-nil-getevent-wf")
790-
})
791810
})
792811

793812
// Test queued workflow with TestWorkflowData type

dbos/system_database.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,11 +1544,10 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err
15441544

15451545
// Serialize the end time before recording
15461546
serializer := newJSONSerializer[time.Time]()
1547-
encodedEndTimeStr, serErr := serializer.Encode(endTime)
1547+
encodedEndTime, serErr := serializer.Encode(endTime)
15481548
if serErr != nil {
15491549
return 0, fmt.Errorf("failed to serialize sleep end time: %w", serErr)
15501550
}
1551-
encodedEndTime := &encodedEndTimeStr
15521551

15531552
// Record the operation result with the calculated end time
15541553
recordInput := recordOperationResultDBInput{

dbos/workflow.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error
223223
parentWorkflowID: workflowState.workflowID,
224224
childWorkflowID: h.workflowID,
225225
stepID: workflowState.nextStepID(),
226-
output: &encodedOutput,
226+
output: encodedOutput,
227227
err: outcome.err,
228228
}
229229
recordResultErr := retry(h.dbosContext, func() error {
@@ -877,7 +877,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
877877
CreatedAt: time.Now(),
878878
Deadline: deadline,
879879
Timeout: timeout,
880-
Input: &encodedInput,
880+
Input: encodedInput,
881881
ApplicationID: c.GetApplicationID(),
882882
QueueName: params.queueName,
883883
DeduplicationID: params.deduplicationID,
@@ -1037,7 +1037,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
10371037
workflowID: workflowID,
10381038
status: status,
10391039
err: err,
1040-
output: &encodedOutput,
1040+
output: encodedOutput,
10411041
})
10421042
}, withRetrierLogger(c.logger))
10431043
if recordErr != nil {
@@ -1323,7 +1323,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption)
13231323
stepName: stepOpts.stepName,
13241324
stepID: stepState.stepID,
13251325
err: stepError,
1326-
output: &encodedStepOutput,
1326+
output: encodedStepOutput,
13271327
}
13281328
recErr := retry(c, func() error {
13291329
return c.systemDB.recordOperationResult(uncancellableCtx, dbInput)
@@ -1349,7 +1349,7 @@ func (c *dbosContext) Send(_ DBOSContext, destinationID string, message any, top
13491349
return retry(c, func() error {
13501350
return c.systemDB.send(c, WorkflowSendInput{
13511351
DestinationID: destinationID,
1352-
Message: &encodedMessage,
1352+
Message: encodedMessage,
13531353
Topic: topic,
13541354
})
13551355
}, withRetrierLogger(c.logger))
@@ -1452,7 +1452,7 @@ func (c *dbosContext) SetEvent(_ DBOSContext, key string, message any) error {
14521452
return retry(c, func() error {
14531453
return c.systemDB.setEvent(c, WorkflowSetEventInput{
14541454
Key: key,
1455-
Message: &encodedMessage,
1455+
Message: encodedMessage,
14561456
})
14571457
}, withRetrierLogger(c.logger))
14581458
}

dbos/workflows_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,8 +1418,11 @@ func TestWorkflowDeadLetterQueue(t *testing.T) {
14181418
resultAny, err := h.GetResult()
14191419
require.NoError(t, err, "failed to get result from handle %d", i)
14201420
// Decode the result from any (which may be float64 after JSON decode) to int
1421+
// Marshal to JSON then unmarshal into the expected type
1422+
jsonBytes, err := json.Marshal(resultAny)
1423+
require.NoError(t, err, "failed to marshal result to JSON")
14211424
var result int
1422-
err = json.Unmarshal([]byte(resultAny.(string)), &result)
1425+
err = json.Unmarshal(jsonBytes, &result)
14231426
require.NoError(t, err, "failed to decode result to int")
14241427
require.Equal(t, 0, result)
14251428
}
@@ -3210,7 +3213,8 @@ func TestWorkflowTimeout(t *testing.T) {
32103213

32113214
// Wait for the workflow to complete and check the result. Should we AwaitedWorkflowCancelled
32123215
result, err := recoveredHandle.GetResult()
3213-
assert.Equal(t, "", result, "expected result to be an empty string")
3216+
// Recovery handles are of type any, so when the handle decoded the result into any, it returned a zero value of any, which is nil, not an empty string.
3217+
assert.Nil(t, result, "expected result to be nil")
32143218
// Check the error type
32153219
dbosErr, ok := err.(*DBOSError)
32163220
require.True(t, ok, "expected error to be of type *DBOSError, got %T", err)

0 commit comments

Comments
 (0)