Skip to content

Commit bcd86f3

Browse files
committed
cleanup
1 parent 7cc6ffb commit bcd86f3

File tree

4 files changed

+37
-94
lines changed

4 files changed

+37
-94
lines changed

dbos/serialization.go

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/base64"
55
"encoding/json"
66
"fmt"
7-
"reflect"
87
)
98

109
// Serializer defines the interface for pluggable serializers.
@@ -30,13 +29,11 @@ func isJSONSerializer(s Serializer) bool {
3029

3130
func (j *JSONSerializer) Encode(data any) (string, error) {
3231
var inputBytes []byte
33-
if !isNilValue(data) {
34-
jsonBytes, err := json.Marshal(data)
35-
if err != nil {
36-
return "", fmt.Errorf("failed to marshal data to JSON: %w", err)
37-
}
38-
inputBytes = jsonBytes
32+
jsonBytes, err := json.Marshal(data)
33+
if err != nil {
34+
return "", fmt.Errorf("failed to marshal data to JSON: %w", err)
3935
}
36+
inputBytes = jsonBytes
4037
return base64.StdEncoding.EncodeToString(inputBytes), nil
4138
}
4239

@@ -99,20 +96,3 @@ func deserialize[T any](serializer Serializer, encoded *string) (T, error) {
9996
}
10097
return typedResult, nil
10198
}
102-
103-
// Handle cases where the provided data interface wraps a nil value (e.g., var p *int; data := any(p). data != nil but the underlying value is nil)
104-
func isNilValue(data any) bool {
105-
if data == nil {
106-
return true
107-
}
108-
v := reflect.ValueOf(data)
109-
// Check if the value is invalid (zero Value from reflect)
110-
if !v.IsValid() {
111-
return true
112-
}
113-
switch v.Kind() {
114-
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Interface:
115-
return v.IsNil()
116-
}
117-
return false
118-
}

dbos/system_database.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -443,16 +443,6 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt
443443
timeoutMs = &millis
444444
}
445445

446-
// Input is already encoded as *string from the typed layer
447-
var inputString *string
448-
if input.status.Input != nil {
449-
encodedInput, ok := input.status.Input.(*string)
450-
if !ok {
451-
return nil, fmt.Errorf("workflow input must be encoded *string, got %T", input.status.Input)
452-
}
453-
inputString = encodedInput
454-
}
455-
456446
// Our DB works with NULL values
457447
var applicationVersion *string
458448
if len(input.status.ApplicationVersion) > 0 {
@@ -524,7 +514,7 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt
524514
updatedAt.UnixMilli(),
525515
timeoutMs,
526516
deadline,
527-
inputString, // encoded input (already *string)
517+
input.status.Input, // encoded input (already *string)
528518
deduplicationID,
529519
input.status.Priority,
530520
WorkflowStatusEnqueued,
@@ -1122,7 +1112,7 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st
11221112
&appVersion,
11231113
originalWorkflow.ApplicationID,
11241114
_DBOS_INTERNAL_QUEUE_NAME,
1125-
originalWorkflow.Input, // Input is already encoded *string from listWorkflows
1115+
originalWorkflow.Input, // encoded
11261116
time.Now().UnixMilli(),
11271117
time.Now().UnixMilli(),
11281118
0)
@@ -1208,7 +1198,6 @@ func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperation
12081198
errorString = &e
12091199
}
12101200

1211-
// input.output is already a *string from the database layer
12121201
var err error
12131202
if input.tx != nil {
12141203
_, err = input.tx.Exec(ctx, query,
@@ -1767,7 +1756,6 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error {
17671756
topic = input.Topic
17681757
}
17691758

1770-
// input.Message is already encoded *string from the typed layer
17711759
insertQuery := fmt.Sprintf(`INSERT INTO %s.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)`, pgx.Identifier{s.schema}.Sanitize())
17721760
_, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, input.Message)
17731761
if err != nil {

dbos/workflow.go

Lines changed: 18 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (ws *workflowState) nextStepID() int {
7575
type workflowOutcome[R any] struct {
7676
result R
7777
err error
78-
needsDecoding bool // true if result came from awaitWorkflowResult (ID conflict path) and needs JSON type conversion
78+
needsDecoding bool // true if result came from awaitWorkflowResult (ID conflict path) and needs decoding
7979
}
8080

8181
// WorkflowHandle provides methods to interact with a running or completed workflow.
@@ -211,22 +211,21 @@ func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error
211211
isWithinWorkflow := ok && workflowState != nil
212212
if isWithinWorkflow {
213213
dbosCtx, ok := h.dbosContext.(*dbosContext)
214-
if !ok {
214+
if !ok { // Should never happen
215215
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("invalid DBOSContext: expected *dbosContext"))
216216
}
217217
if dbosCtx.serializer == nil {
218218
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("no serializer configured in DBOSContext"))
219219
}
220-
encodedOutputStr, encErr := dbosCtx.serializer.Encode(typedResult)
220+
encodedOutput, encErr := dbosCtx.serializer.Encode(typedResult)
221221
if encErr != nil {
222222
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr))
223223
}
224-
encodedOutput := &encodedOutputStr
225224
recordGetResultInput := recordChildGetResultDBInput{
226225
parentWorkflowID: workflowState.workflowID,
227226
childWorkflowID: h.workflowID,
228227
stepID: workflowState.nextStepID(),
229-
output: encodedOutput,
228+
output: &encodedOutput,
230229
err: outcome.err,
231230
}
232231
recordResultErr := retry(h.dbosContext, func() error {
@@ -279,23 +278,15 @@ func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error)
279278
workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState)
280279
isWithinWorkflow := ok && workflowState != nil
281280
if isWithinWorkflow {
282-
dbosCtx, ok := h.dbosContext.(*dbosContext)
283-
if !ok {
284-
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("invalid DBOSContext: expected *dbosContext"))
285-
}
286-
if dbosCtx.serializer == nil {
287-
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("no serializer configured in DBOSContext"))
288-
}
289-
encodedOutputStr, encErr := dbosCtx.serializer.Encode(typedResult)
290-
if encErr != nil {
291-
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr))
281+
encodedResultStr, ok := encodedResult.(*string)
282+
if !ok { // Should never happen
283+
return *new(R), newWorkflowUnexpectedResultType(h.workflowID, "string (encoded)", fmt.Sprintf("%T", encodedResult))
292284
}
293-
encodedOutput := &encodedOutputStr
294285
recordGetResultInput := recordChildGetResultDBInput{
295286
parentWorkflowID: workflowState.workflowID,
296287
childWorkflowID: h.workflowID,
297288
stepID: workflowState.nextStepID(),
298-
output: encodedOutput,
289+
output: encodedResultStr,
299290
err: err,
300291
}
301292
recordResultErr := retry(h.dbosContext, func() error {
@@ -722,7 +713,7 @@ func RunWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], input P, opts
722713
var typedResult R
723714

724715
// Handle nil results - nil cannot be type-asserted to any interface
725-
if isNilValue(outcome.result) {
716+
if outcome.result == nil {
726717
typedOutcomeChan <- workflowOutcome[R]{
727718
result: typedResult,
728719
err: resultErr,
@@ -732,10 +723,9 @@ func RunWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], input P, opts
732723

733724
// Get serializer from context
734725
dbosCtx, ok := handle.dbosContext.(*dbosContext)
735-
if !ok {
736-
resultErr = errors.Join(resultErr, fmt.Errorf("invalid DBOSContext type"))
726+
if !ok { // Likely a mocked path
737727
typedOutcomeChan <- workflowOutcome[R]{
738-
result: typedResult,
728+
result: outcome.result.(R),
739729
err: resultErr,
740730
}
741731
return
@@ -1201,7 +1191,6 @@ func RunAsStep[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (R, error
12011191
return *new(R), newStepExecutionError("", "", fmt.Errorf("step function cannot be nil"))
12021192
}
12031193

1204-
// No gob registration required
12051194
var serializer Serializer
12061195
if c, ok := ctx.(*dbosContext); ok {
12071196
serializer = c.serializer
@@ -1289,10 +1278,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption)
12891278
}
12901279
if recordedOutput != nil {
12911280
// Return the encoded output - decoding will happen in RunAsStep[R] when we know the target type
1292-
if recordedOutput.output != nil {
1293-
return recordedOutput.output, recordedOutput.err
1294-
}
1295-
return nil, recordedOutput.err
1281+
return recordedOutput.output, recordedOutput.err
12961282
}
12971283

12981284
// Spawn a child DBOSContext with the step state
@@ -1372,16 +1358,15 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption)
13721358

13731359
func (c *dbosContext) Send(_ DBOSContext, destinationID string, message any, topic string) error {
13741360
// Serialize the message before sending
1375-
encodedMessageStr, err := c.serializer.Encode(message)
1361+
encodedMessage, err := c.serializer.Encode(message)
13761362
if err != nil {
13771363
return fmt.Errorf("failed to serialize message: %w", err)
13781364
}
1379-
encodedMessage := &encodedMessageStr
13801365

13811366
return retry(c, func() error {
13821367
return c.systemDB.send(c, WorkflowSendInput{
13831368
DestinationID: destinationID,
1384-
Message: encodedMessage,
1369+
Message: &encodedMessage,
13851370
Topic: topic,
13861371
})
13871372
}, withRetrierLogger(c.logger))
@@ -1413,18 +1398,9 @@ func (c *dbosContext) Recv(_ DBOSContext, topic string, timeout time.Duration) (
14131398
Topic: topic,
14141399
Timeout: timeout,
14151400
}
1416-
encodedMsg, err := retryWithResult(c, func() (*string, error) {
1401+
return retryWithResult(c, func() (*string, error) {
14171402
return c.systemDB.recv(c, input)
14181403
}, withRetrierLogger(c.logger))
1419-
if err != nil {
1420-
return nil, err
1421-
}
1422-
1423-
if encodedMsg == nil {
1424-
return nil, nil
1425-
}
1426-
// Return encoded string - decoding will happen in Recv[T] when we know the target type
1427-
return encodedMsg, nil
14281404
}
14291405

14301406
// Recv receives a message sent to this workflow with type safety.
@@ -1483,16 +1459,15 @@ func Recv[T any](ctx DBOSContext, topic string, timeout time.Duration) (T, error
14831459

14841460
func (c *dbosContext) SetEvent(_ DBOSContext, key string, message any) error {
14851461
// Serialize the event value before storing
1486-
encodedMessageStr, err := c.serializer.Encode(message)
1462+
encodedMessage, err := c.serializer.Encode(message)
14871463
if err != nil {
14881464
return fmt.Errorf("failed to serialize event value: %w", err)
14891465
}
1490-
encodedMessage := &encodedMessageStr
14911466

14921467
return retry(c, func() error {
14931468
return c.systemDB.setEvent(c, WorkflowSetEventInput{
14941469
Key: key,
1495-
Message: encodedMessage,
1470+
Message: &encodedMessage,
14961471
})
14971472
}, withRetrierLogger(c.logger))
14981473
}
@@ -1511,7 +1486,6 @@ func SetEvent[P any](ctx DBOSContext, key string, message P) error {
15111486
if ctx == nil {
15121487
return errors.New("ctx cannot be nil")
15131488
}
1514-
// No gob registration required
15151489
return ctx.SetEvent(ctx, key, message)
15161490
}
15171491

@@ -1527,15 +1501,9 @@ func (c *dbosContext) GetEvent(_ DBOSContext, targetWorkflowID, key string, time
15271501
Key: key,
15281502
Timeout: timeout,
15291503
}
1530-
encodedValue, err := retryWithResult(c, func() (any, error) {
1504+
return retryWithResult(c, func() (any, error) {
15311505
return c.systemDB.getEvent(c, input)
15321506
}, withRetrierLogger(c.logger))
1533-
if err != nil {
1534-
return nil, err
1535-
}
1536-
1537-
// Return encoded string - decoding will happen in GetEvent[T] when we know the target type
1538-
return encodedValue, nil
15391507
}
15401508

15411509
// GetEvent retrieves a key-value event from a target workflow with type safety.

dbos/workflows_test.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,10 +1110,11 @@ func TestWorkflowRecovery(t *testing.T) {
11101110
dbosCtx := setupDBOS(t, true, true, nil)
11111111

11121112
var (
1113-
recoveryCounters []int64
1114-
recoveryEvents []*Event
1115-
blockingEvents []*Event
1116-
secondStepErrors []error
1113+
recoveryCounters []int64
1114+
recoveryEvents []*Event
1115+
blockingEvents []*Event
1116+
secondStepErrors []error
1117+
secondStepErrorsMu sync.Mutex
11171118
)
11181119

11191120
recoveryWorkflow := func(dbosCtx DBOSContext, index int) (int64, error) {
@@ -1135,7 +1136,9 @@ func TestWorkflowRecovery(t *testing.T) {
11351136
return fmt.Sprintf("completed-%d", index), nil
11361137
}, WithStepName(fmt.Sprintf("BlockingStep-%d", index)))
11371138
if err != nil {
1139+
secondStepErrorsMu.Lock()
11381140
secondStepErrors = append(secondStepErrors, err)
1141+
secondStepErrorsMu.Unlock()
11391142
return 0, err
11401143
}
11411144

@@ -1259,8 +1262,12 @@ func TestWorkflowRecovery(t *testing.T) {
12591262

12601263
// At least 5 of the 2nd steps should have errored due to execution race
12611264
// Check they are DBOSErrors with StepExecutionError wrapping a ConflictingIDError
1262-
require.GreaterOrEqual(t, len(secondStepErrors), 5, "expected at least 5 errors from second steps due to recovery race, got %d", len(secondStepErrors))
1263-
for _, err := range secondStepErrors {
1265+
secondStepErrorsMu.Lock()
1266+
errorsCopy := make([]error, len(secondStepErrors))
1267+
copy(errorsCopy, secondStepErrors)
1268+
secondStepErrorsMu.Unlock()
1269+
require.GreaterOrEqual(t, len(errorsCopy), 5, "expected at least 5 errors from second steps due to recovery race, got %d", len(errorsCopy))
1270+
for _, err := range errorsCopy {
12641271
dbosErr, ok := err.(*DBOSError)
12651272
require.True(t, ok, "expected error to be of type *DBOSError, got %T", err)
12661273
require.Equal(t, StepExecutionError, dbosErr.Code, "expected error code to be StepExecutionError, got %v", dbosErr.Code)

0 commit comments

Comments
 (0)