Skip to content

Commit a2b0872

Browse files
committed
cleanup
1 parent 38c4ff7 commit a2b0872

File tree

8 files changed

+246
-235
lines changed

8 files changed

+246
-235
lines changed

dbos/client.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu
149149
}
150150

151151
// Serialize input before storing in workflow status
152-
encodedInputStr, err := serialize(dbosCtx, params.workflowInput)
152+
encodedInputStr, err := serialize(dbosCtx.serializer, params.workflowInput)
153153
if err != nil {
154154
return nil, fmt.Errorf("failed to serialize workflow input: %w", err)
155155
}
@@ -247,7 +247,6 @@ func Enqueue[P any, R any](c Client, queueName, workflowName string, input P, op
247247
if c == nil {
248248
return nil, errors.New("client cannot be nil")
249249
}
250-
251250

252251
// Call the interface method with the same signature
253252
handle, err := c.Enqueue(queueName, workflowName, input, opts...)

dbos/queue.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,8 @@ func (qr *queueRunner) run(ctx *dbosContext) {
224224
continue
225225
}
226226

227-
input, err := ctx.serializer.Decode(workflow.input)
228-
if err != nil {
229-
qr.logger.Error("Failed to decode workflow input", "workflow_id", workflow.id, "error", err)
230-
continue
231-
}
232-
233-
_, err = registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id))
227+
// Pass encoded input directly - decoding will happen in workflow wrapper when we know the target type
228+
_, err = registeredWorkflow.wrappedFunction(ctx, workflow.input, WithWorkflowID(workflow.id))
234229
if err != nil {
235230
qr.logger.Error("Error running queued workflow", "error", err)
236231
}

dbos/recovery.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow
1414
}
1515

1616
for _, workflow := range pendingWorkflows {
17-
// Deserialize the workflow input
18-
var decodedInput any
17+
// Pass encoded input directly - decoding will happen in workflow wrapper when we know the target type
18+
var encodedInput *string
1919
if workflow.Input != nil {
2020
inputString, ok := workflow.Input.(*string)
2121
if !ok {
2222
ctx.logger.Warn("Skipping workflow recovery due to invalid input type", "workflow_id", workflow.ID, "name", workflow.Name, "input_type", workflow.Input)
2323
continue
2424
}
25-
decodedInput, err = ctx.serializer.Decode(inputString)
26-
if err != nil {
27-
ctx.logger.Warn("Skipping workflow recovery due to input decoding failure", "workflow_id", workflow.ID, "name", workflow.Name, "error", err)
28-
continue
29-
}
25+
encodedInput = inputString
3026
}
3127

3228
if workflow.QueueName != "" {
@@ -63,7 +59,7 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow
6359
WithWorkflowID(workflow.ID),
6460
}
6561
// Create a workflow context from the executor context
66-
handle, err := registeredWorkflow.wrappedFunction(ctx, decodedInput, opts...)
62+
handle, err := registeredWorkflow.wrappedFunction(ctx, encodedInput, opts...)
6763
if err != nil {
6864
return nil, err
6965
}

dbos/serialization.go

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ func NewJSONSerializer() *JSONSerializer {
2323
return &JSONSerializer{}
2424
}
2525

26+
func isJSONSerializer(s Serializer) bool {
27+
_, ok := s.(*JSONSerializer)
28+
return ok
29+
}
30+
2631
func (j *JSONSerializer) Encode(data any) (string, error) {
2732
var inputBytes []byte
2833
if !isNilValue(data) {
@@ -53,69 +58,53 @@ func (j *JSONSerializer) Decode(data *string) (any, error) {
5358
return result, nil
5459
}
5560

56-
// serialize serializes data using the serializer from the DBOSContext
57-
// convenience helper to also check the context & serializer
58-
func serialize(ctx DBOSContext, data any) (string, error) {
59-
dbosCtx, ok := ctx.(*dbosContext)
60-
if !ok {
61-
return "", fmt.Errorf("invalid DBOSContext: expected *dbosContext")
61+
// serialize serializes data using the provided serializer
62+
func serialize[T any](serializer Serializer, data T) (string, error) {
63+
if serializer == nil {
64+
return "", fmt.Errorf("serializer cannot be nil")
6265
}
63-
if dbosCtx.serializer == nil {
64-
return "", fmt.Errorf("no serializer configured in DBOSContext")
65-
}
66-
return dbosCtx.serializer.Encode(data)
67-
}
68-
69-
func isJSONSerializer(s Serializer) bool {
70-
_, ok := s.(*JSONSerializer)
71-
return ok
66+
return serializer.Encode(data)
7267
}
7368

74-
// convertJSONToType converts a JSON-decoded value (map[string]interface{}) to type T
75-
// via marshal/unmarshal round-trip.
76-
//
77-
// This is needed because JSON deserialization loses type information when decoding
78-
// into `any` - it converts structs to map[string]interface{}, numbers to float64, etc.
79-
// By re-marshaling and unmarshaling into a typed target, we (mostly) restore the original structure.
80-
// We should be able to get rid of this when we lift encoding/decoding outside of the system database.
81-
func convertJSONToType[T any](value any) (T, error) {
82-
if value == nil {
83-
return *new(T), nil
84-
}
85-
86-
jsonBytes, err := json.Marshal(value)
87-
if err != nil {
88-
return *new(T), fmt.Errorf("marshaling for type conversion: %w", err)
69+
// deserialize decodes an encoded string directly into a typed variable.
70+
// For JSON serializer, this decodes directly into the target type, preserving type information.
71+
// For other serializers, it decodes into any and then type-asserts.
72+
// (we don't want generic Serializer interface because it would require 1 serializer per type)
73+
func deserialize[T any](serializer Serializer, encoded *string) (T, error) {
74+
if serializer == nil {
75+
return *new(T), fmt.Errorf("serializer cannot be nil")
8976
}
9077

91-
// Check if T is an interface type
9278
var zero T
93-
typeOfT := reflect.TypeOf(&zero).Elem()
79+
if encoded == nil || *encoded == "" {
80+
return zero, nil
81+
}
9482

95-
if typeOfT.Kind() == reflect.Interface {
96-
// T is interface - need to get concrete type from value
97-
concreteType := reflect.TypeOf(value)
98-
if concreteType.Kind() == reflect.Pointer {
99-
concreteType = concreteType.Elem()
83+
if isJSONSerializer(serializer) {
84+
// For JSON serializer, decode directly into the target type to preserve type information
85+
// We cannot just use the serializer's Decode method and recast -- the type inormation would be lost
86+
dataBytes, err := base64.StdEncoding.DecodeString(*encoded)
87+
if err != nil {
88+
return zero, fmt.Errorf("failed to decode base64 data: %w", err)
10089
}
10190

102-
// Create new instance of concrete type
103-
newInstance := reflect.New(concreteType)
91+
// We could check and error explicitly if T is an interface type.
10492

105-
// Unmarshal into the concrete type
106-
if err := json.Unmarshal(jsonBytes, newInstance.Interface()); err != nil {
107-
return *new(T), fmt.Errorf("unmarshaling for type conversion: %w", err)
93+
if err := json.Unmarshal(dataBytes, &zero); err != nil {
94+
return zero, fmt.Errorf("failed to unmarshal JSON data: %w", err)
10895
}
109-
110-
// Convert to interface type T
111-
return newInstance.Elem().Interface().(T), nil
96+
return zero, nil
11297
}
11398

114-
var typedResult T
115-
if err := json.Unmarshal(jsonBytes, &typedResult); err != nil {
116-
return *new(T), fmt.Errorf("unmarshaling for type conversion: %w", err)
99+
// For other serializers, just call the decoder and type-assert
100+
decoded, err := serializer.Decode(encoded)
101+
if err != nil {
102+
return zero, err
103+
}
104+
typedResult, ok := decoded.(T)
105+
if !ok {
106+
return zero, fmt.Errorf("cannot convert decoded value of type %T to %T", decoded, zero)
117107
}
118-
119108
return typedResult, nil
120109
}
121110

dbos/serialization_test.go

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,18 +1042,15 @@ func TestSerializer(t *testing.T) {
10421042
isJSON := isJSONSerializer(dbosCtx.serializer)
10431043

10441044
if isJSON {
1045-
// JSON serializer returns map[string]any
1046-
inputMap, ok := workflow.Input.(map[string]any)
1047-
require.True(t, ok, "Input should be map[string]any for JSON")
1048-
assert.Equal(t, input.Message, inputMap["Message"], "Message should match in input")
1049-
assert.Equal(t, float64(input.Value), inputMap["Value"], "Value should match in input")
1050-
1051-
outputMap, ok := workflow.Output.(map[string]any)
1052-
require.True(t, ok, "Output should be map[string]any for JSON")
1053-
assert.Equal(t, input.Message, outputMap["Message"], "Message should match in output")
1054-
assert.Equal(t, float64(input.Value), outputMap["Value"], "Value should match in output")
1045+
// JSON serializer returns map[string]any. We need to convert it to a _concrete_ type
1046+
inputConcrete, err := convertJSONToType[ConcreteDataProvider](workflow.Input)
1047+
require.NoError(t, err, "Failed to convert workflow input to ConcreteDataProvider")
1048+
assert.Equal(t, input, inputConcrete, "Workflow input should match input")
1049+
1050+
outputConcrete, err := convertJSONToType[ConcreteDataProvider](workflow.Output)
1051+
require.NoError(t, err, "Failed to convert workflow output to ConcreteDataProvider")
1052+
assert.Equal(t, input, outputConcrete, "Workflow output should match input")
10551053
} else {
1056-
// Gob serializer preserves the concrete type
10571054
inputConcrete, ok := workflow.Input.(ConcreteDataProvider)
10581055
require.True(t, ok, "Input should be ConcreteDataProvider for Gob")
10591056
assert.Equal(t, input, inputConcrete, "Input should match")
@@ -1062,6 +1059,46 @@ func TestSerializer(t *testing.T) {
10621059
require.True(t, ok, "Output should be ConcreteDataProvider for Gob")
10631060
assert.Equal(t, input, outputConcrete, "Output should match")
10641061
}
1062+
1063+
// Test GetWorkflowSteps for interface types
1064+
t.Run("GetWorkflowSteps", func(t *testing.T) {
1065+
steps, err := GetWorkflowSteps(executor, handle.GetWorkflowID())
1066+
require.NoError(t, err, "Failed to get workflow steps")
1067+
require.Len(t, steps, 1, "Expected 1 step")
1068+
1069+
step := steps[0]
1070+
require.NotNil(t, step.Output, "Step output should not be nil")
1071+
assert.Nil(t, step.Error, "Step should not have error")
1072+
1073+
if isJSON {
1074+
// JSON serializer returns map[string]any, convert to ConcreteDataProvider
1075+
outputConcrete, err := convertJSONToType[ConcreteDataProvider](step.Output)
1076+
require.NoError(t, err, "Failed to convert step output to ConcreteDataProvider")
1077+
assert.Equal(t, input, outputConcrete, "Step output should match input")
1078+
} else {
1079+
outputConcrete, ok := step.Output.(ConcreteDataProvider)
1080+
require.True(t, ok, "Output should be ConcreteDataProvider for Gob")
1081+
assert.Equal(t, input, outputConcrete, "Step output should match input")
1082+
}
1083+
})
1084+
1085+
// Test RetrieveWorkflow for interface types
1086+
// TODO: not supported for interface types w/o storing the type information in the DB
1087+
/*
1088+
t.Run("RetrieveWorkflow", func(t *testing.T) {
1089+
h2, err := RetrieveWorkflow[DataProvider](executor, handle.GetWorkflowID())
1090+
require.NoError(t, err, "Failed to retrieve workflow")
1091+
1092+
retrievedResult, err := h2.GetResult()
1093+
require.NoError(t, err, "Failed to get retrieved workflow result")
1094+
1095+
// For interface types, we need to check the concrete type
1096+
concreteRetrievedResult, ok := retrievedResult.(ConcreteDataProvider)
1097+
require.True(t, ok, "Retrieved result should be ConcreteDataProvider type")
1098+
assert.Equal(t, input.Message, concreteRetrievedResult.Message, "Message should match")
1099+
assert.Equal(t, input.Value, concreteRetrievedResult.Value, "Value should match")
1100+
})
1101+
*/
10651102
})
10661103

10671104
// Test nil values with pointer type workflow

dbos/system_database.go

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,27 +1541,11 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err
15411541
return 0, fmt.Errorf("no recorded end time for recorded sleep operation")
15421542
}
15431543

1544-
// Deserialize the recorded end time
1545-
decodedOutput, err := s.serializer.Decode(recordedResult.output)
1544+
// Decode the recorded end time directly into time.Time
1545+
// recordedResult.output is an encoded *string
1546+
endTime, err = deserialize[time.Time](s.serializer, recordedResult.output)
15461547
if err != nil {
1547-
return 0, fmt.Errorf("failed to deserialize sleep end time: %w", err)
1548-
}
1549-
1550-
// The output should be a time.Time representing the end time
1551-
// Because checkOperationExecution returns encoded string, we need to decode it into a time.Time
1552-
endTimeInterface, ok := decodedOutput.(time.Time)
1553-
if !ok {
1554-
// JSON serializer loses type information - convert using convertJSONToType
1555-
if isJSONSerializer(s.serializer) {
1556-
endTime, err = convertJSONToType[time.Time](decodedOutput)
1557-
if err != nil {
1558-
return 0, fmt.Errorf("failed to convert recorded output to time.Time: %w", err)
1559-
}
1560-
} else {
1561-
return 0, fmt.Errorf("decoded output is not a time.Time: %T", decodedOutput)
1562-
}
1563-
} else {
1564-
endTime = endTimeInterface
1548+
return 0, fmt.Errorf("failed to decode sleep end time: %w", err)
15651549
}
15661550

15671551
if recordedResult.err != nil { // This should never happen

dbos/utils_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dbos
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"net/url"
78
"os"
@@ -181,3 +182,22 @@ func checkWfStatus(ctx DBOSContext, expectedStatus WorkflowStatusType) (bool, er
181182
}
182183
return false, nil
183184
}
185+
186+
// Re-encode the value as JSON and then unmarshal into the target type
187+
func convertJSONToType[T any](value any) (T, error) {
188+
if value == nil {
189+
return *new(T), nil
190+
}
191+
192+
jsonBytes, err := json.Marshal(value)
193+
if err != nil {
194+
return *new(T), fmt.Errorf("marshaling for type conversion: %w", err)
195+
}
196+
197+
var typedResult T
198+
if err := json.Unmarshal(jsonBytes, &typedResult); err != nil {
199+
return *new(T), fmt.Errorf("unmarshaling for type conversion: %w", err)
200+
}
201+
202+
return typedResult, nil
203+
}

0 commit comments

Comments
 (0)