Skip to content

Commit 033e794

Browse files
committed
fix for mocking
1 parent dd8d725 commit 033e794

File tree

2 files changed

+34
-30
lines changed

2 files changed

+34
-30
lines changed

dbos/dbos.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,16 @@ func (c *dbosContext) Value(key any) any {
193193
return c.ctx.Value(key)
194194
}
195195

196+
// getSerializer extracts the serializer from a DBOSContext.
197+
// Returns the serializer if the context is a *dbosContext with a configured serializer,
198+
// or nil if the context is not a *dbosContext (for testing/mocking scenarios).
199+
func getSerializer(ctx DBOSContext) Serializer {
200+
if dbosCtx, ok := ctx.(*dbosContext); ok {
201+
return dbosCtx.serializer
202+
}
203+
return nil
204+
}
205+
196206
// WithValue returns a copy of the DBOS context with the given key-value pair.
197207
// This is similar to context.WithValue but maintains DBOS context capabilities.
198208
// No-op if the provided context is not a concrete dbos.dbosContext.
@@ -370,8 +380,6 @@ func NewDBOSContext(ctx context.Context, inputConfig Config) (DBOSContext, error
370380
initExecutor.executorID = config.ExecutorID
371381
initExecutor.serializer = config.Serializer
372382

373-
374-
375383
initExecutor.applicationID = os.Getenv("DBOS__APPID")
376384

377385
newSystemDatabaseInputs := newSystemDatabaseInput{

dbos/workflow.go

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error)
269269
return *new(R), newWorkflowUnexpectedResultType(h.workflowID, "string (encoded)", fmt.Sprintf("%T", encodedResult))
270270
}
271271
var deserErr error
272-
typedResult, deserErr = deserialize[R](h.dbosContext.(*dbosContext).serializer, encodedStr)
272+
typedResult, deserErr = deserialize[R](getSerializer(h.dbosContext), encodedStr)
273273
if deserErr != nil {
274274
return *new(R), fmt.Errorf("failed to deserialize workflow result: %w", deserErr)
275275
}
@@ -530,7 +530,7 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], opts ...
530530
return *new(R), newWorkflowUnexpectedInputType(fqn, "*string (encoded)", fmt.Sprintf("%T", input))
531531
}
532532
// Decode directly into the target type
533-
typedInput, err := deserialize[P](ctx.(*dbosContext).serializer, encodedInput)
533+
typedInput, err := deserialize[P](getSerializer(ctx), encodedInput)
534534
if err != nil {
535535
return *new(R), newWorkflowExecutionError(workflowID, err)
536536
}
@@ -1191,10 +1191,7 @@ func RunAsStep[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (R, error
11911191
return *new(R), newStepExecutionError("", "", fmt.Errorf("step function cannot be nil"))
11921192
}
11931193

1194-
var serializer Serializer
1195-
if c, ok := ctx.(*dbosContext); ok {
1196-
serializer = c.serializer
1197-
}
1194+
serializer := getSerializer(ctx)
11981195

11991196
// Append WithStepName option to ensure the step name is set. This will not erase a user-provided step name
12001197
stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
@@ -1427,27 +1424,27 @@ func Recv[T any](ctx DBOSContext, topic string, timeout time.Duration) (T, error
14271424
return *new(T), nil
14281425
}
14291426

1430-
// Decode the message directly into the target type
1431-
// msg is an encoded *string from Recv
1432-
encodedMsg, ok := msg.(*string)
1433-
if !ok {
1434-
return *new(T), newWorkflowUnexpectedResultType("", "string (encoded)", fmt.Sprintf("%T", msg))
1435-
}
1436-
14371427
var typedMessage T
1438-
var serializer Serializer
1439-
if dbosCtx, ok := ctx.(*dbosContext); ok {
1440-
serializer = dbosCtx.serializer
1428+
serializer := getSerializer(ctx)
1429+
if serializer != nil {
1430+
encodedMsg, ok := msg.(*string)
1431+
if !ok {
1432+
workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error
1433+
return *new(T), newWorkflowUnexpectedResultType(workflowID, "string (encoded)", fmt.Sprintf("%T", msg))
1434+
}
14411435
var decodeErr error
14421436
typedMessage, decodeErr = deserialize[T](serializer, encodedMsg)
14431437
if decodeErr != nil {
14441438
return *new(T), fmt.Errorf("decoding received message to type %T: %w", *new(T), decodeErr)
14451439
}
1440+
return typedMessage, nil
14461441
} else {
1442+
// Fallback for testing/mocking scenarios where serializer is nil
14471443
var ok bool
14481444
typedMessage, ok = msg.(T)
14491445
if !ok {
1450-
return *new(T), newWorkflowUnexpectedResultType("", fmt.Sprintf("%T", new(T)), fmt.Sprintf("%T", msg))
1446+
workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error
1447+
return *new(T), newWorkflowUnexpectedResultType(workflowID, fmt.Sprintf("%T", new(T)), fmt.Sprintf("%T", msg))
14511448
}
14521449
}
14531450
return typedMessage, nil
@@ -1528,17 +1525,15 @@ func GetEvent[T any](ctx DBOSContext, targetWorkflowID, key string, timeout time
15281525
return *new(T), nil
15291526
}
15301527

1531-
// Decode the event value directly into the target type
1532-
// value is an encoded *string from GetEvent
1533-
encodedValue, ok := value.(*string)
1534-
if !ok {
1535-
return *new(T), newWorkflowUnexpectedResultType("", "string (encoded)", fmt.Sprintf("%T", value))
1536-
}
1537-
15381528
var typedValue T
1539-
var serializer Serializer
1540-
if dbosCtx, ok := ctx.(*dbosContext); ok {
1541-
serializer = dbosCtx.serializer
1529+
serializer := getSerializer(ctx)
1530+
if serializer != nil {
1531+
encodedValue, ok := value.(*string)
1532+
if !ok {
1533+
workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error
1534+
return *new(T), newWorkflowUnexpectedResultType(workflowID, "string (encoded)", fmt.Sprintf("%T", value))
1535+
}
1536+
15421537
typedValue, decodeErr := deserialize[T](serializer, encodedValue)
15431538
if decodeErr != nil {
15441539
return *new(T), fmt.Errorf("decoding event value to type %T: %w", *new(T), decodeErr)
@@ -1548,7 +1543,8 @@ func GetEvent[T any](ctx DBOSContext, targetWorkflowID, key string, timeout time
15481543
var ok bool
15491544
typedValue, ok = value.(T)
15501545
if !ok {
1551-
return *new(T), newWorkflowUnexpectedResultType("", fmt.Sprintf("%T", new(T)), fmt.Sprintf("%T", value))
1546+
workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error
1547+
return *new(T), newWorkflowUnexpectedResultType(workflowID, fmt.Sprintf("%T", new(T)), fmt.Sprintf("%T", value))
15521548
}
15531549
}
15541550
return typedValue, nil

0 commit comments

Comments
 (0)