From 8d5a2ca36f487e526b5550b97eb7d25d47b3978e Mon Sep 17 00:00:00 2001 From: maxdml Date: Mon, 28 Jul 2025 22:27:16 -0700 Subject: [PATCH 01/30] wip --- dbos/dbos.go | 116 ++++++++++++++++++++++++++++----------------- dbos/dbos_test.go | 4 +- dbos/queue.go | 2 +- dbos/recovery.go | 2 +- dbos/utils_test.go | 15 ++++-- dbos/workflow.go | 91 ++++++++++++++++++----------------- 6 files changed, 132 insertions(+), 98 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index 60bd8c44..acd10f40 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -19,8 +19,6 @@ var ( _DEFAULT_ADMIN_SERVER_PORT = 3001 ) -var workflowScheduler *cron.Cron // Global because accessed during workflow registration before the dbos singleton is initialized - var logger *slog.Logger // Global because accessed everywhere inside the library func getLogger() *slog.Logger { @@ -62,25 +60,56 @@ func processConfig(inputConfig *Config) (*Config, error) { return dbosConfig, nil } +type DBOSExecutor interface { + Launch() error + Shutdown() + + RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) + RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) + + GetWorkflowScheduler() *cron.Cron + GetApplicationVersion() string +} + var dbos *executor // DBOS singleton instance type executor struct { - systemDB SystemDatabase + systemDB SystemDatabase + adminServer *adminServer + config *Config + // Queue runner context and cancel function queueRunnerCtx context.Context queueRunnerCancelFunc context.CancelFunc queueRunnerDone chan struct{} - adminServer *adminServer - config *Config - applicationVersion string - applicationID string - executorID string - workflowsWg *sync.WaitGroup + // Application metadata + applicationVersion string + applicationID string + executorID string + // Wait group for workflow goroutines + workflowsWg *sync.WaitGroup + // Workflow registry + workflowRegistry map[string]workflowRegistryEntry + workflowRegMutex sync.RWMutex + // Workflow scheduler + workflowScheduler *cron.Cron } -func Initialize(inputConfig Config) error { +func (e *executor) GetWorkflowScheduler() *cron.Cron { + if e.workflowScheduler == nil { + e.workflowScheduler = cron.New(cron.WithSeconds()) + } + return e.workflowScheduler +} + +func (e *executor) GetApplicationVersion() string { + return e.applicationVersion +} + +// TODO: use a normal builder pattern name (NewDBOSExecutor) +func Initialize(inputConfig Config) (DBOSExecutor, error) { if dbos != nil { fmt.Println("warning: DBOS instance already initialized, skipping re-initialization") - return newInitializationError("DBOS already initialized") + return nil, newInitializationError("DBOS already initialized") } initExecutor := &executor{ @@ -90,7 +119,7 @@ func Initialize(inputConfig Config) error { // Load & process the configuration config, err := processConfig(&inputConfig) if err != nil { - return newInitializationError(err.Error()) + return nil, newInitializationError(err.Error()) } initExecutor.config = config @@ -119,26 +148,26 @@ func Initialize(inputConfig Config) error { // Create the system database systemDB, err := NewSystemDatabase(config.DatabaseURL) if err != nil { - return newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) + return nil, newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) } initExecutor.systemDB = systemDB logger.Info("System database initialized") + // Initialize the workflow registry + initExecutor.workflowRegistry = make(map[string]workflowRegistryEntry) + // Set the global dbos instance dbos = initExecutor - return nil + return initExecutor, nil } -func Launch() error { - if dbos == nil { - return newInitializationError("DBOS instance not initialized, call Initialize first") - } +func (e *executor) Launch() error { // Start the system database - dbos.systemDB.Launch(context.Background()) + e.systemDB.Launch(context.Background()) // Start the admin server if configured - if dbos.config.AdminServer { + if e.config.AdminServer { adminServer := newAdminServer(_DEFAULT_ADMIN_SERVER_PORT) err := adminServer.Start() if err != nil { @@ -151,25 +180,25 @@ func Launch() error { // Create context with cancel function for queue runner ctx, cancel := context.WithCancel(context.Background()) - dbos.queueRunnerCtx = ctx - dbos.queueRunnerCancelFunc = cancel - dbos.queueRunnerDone = make(chan struct{}) + e.queueRunnerCtx = ctx + e.queueRunnerCancelFunc = cancel + e.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(dbos.queueRunnerDone) + defer close(e.queueRunnerDone) queueRunner(ctx) }() logger.Info("Queue runner started") // Start the workflow scheduler if it has been initialized - if workflowScheduler != nil { - workflowScheduler.Start() + if e.workflowScheduler != nil { + e.workflowScheduler.Start() logger.Info("Workflow scheduler started") } // Run a round of recovery on the local executor - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{dbos.executorID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it? + recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{e.executorID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it? if err != nil { return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err)) } @@ -177,29 +206,29 @@ func Launch() error { logger.Info("Recovered pending workflows", "count", len(recoveryHandles)) } - logger.Info("DBOS initialized", "app_version", dbos.applicationVersion, "executor_id", dbos.executorID) + logger.Info("DBOS initialized", "app_version", e.applicationVersion, "executor_id", e.executorID) return nil } -func Shutdown() { - if dbos == nil { +func (e *executor) Shutdown() { + if e == nil { fmt.Println("DBOS instance is nil, cannot shutdown") return } // XXX is there a way to ensure all workflows goroutine are done before closing? - dbos.workflowsWg.Wait() + e.workflowsWg.Wait() // Cancel the context to stop the queue runner - if dbos.queueRunnerCancelFunc != nil { - dbos.queueRunnerCancelFunc() + if e.queueRunnerCancelFunc != nil { + e.queueRunnerCancelFunc() // Wait for queue runner to finish - <-dbos.queueRunnerDone + <-e.queueRunnerDone getLogger().Info("Queue runner stopped") } - if workflowScheduler != nil { - ctx := workflowScheduler.Stop() + if e.workflowScheduler != nil { + ctx := e.workflowScheduler.Stop() // Wait for all running jobs to complete with 5-second timeout timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -212,25 +241,26 @@ func Shutdown() { } } - if dbos.systemDB != nil { - dbos.systemDB.Shutdown() - dbos.systemDB = nil + if e.systemDB != nil { + e.systemDB.Shutdown() + e.systemDB = nil } - if dbos.adminServer != nil { - err := dbos.adminServer.Shutdown() + if e.adminServer != nil { + err := e.adminServer.Shutdown() if err != nil { getLogger().Error("Failed to shutdown admin server", "error", err) } else { getLogger().Info("Admin server shutdown complete") } - dbos.adminServer = nil + e.adminServer = nil } if logger != nil { logger = nil } - dbos = nil + // XX now responsiblity of the caller right? + e = nil } func GetBinaryHash() (string, error) { diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 192b4c8c..4e095c6c 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -12,7 +12,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { DatabaseURL: databaseURL, } - err := Initialize(config) + _, err := Initialize(config) if err == nil { t.Fatal("expected error when app name is missing, but got none") } @@ -37,7 +37,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { AppName: "test-app", } - err := Initialize(config) + _, err := Initialize(config) if err == nil { t.Fatal("expected error when database URL is missing, but got none") } diff --git a/dbos/queue.go b/dbos/queue.go index 6cc6ec05..ecb48ce8 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -143,7 +143,7 @@ func queueRunner(ctx context.Context) { } for _, workflow := range dequeuedWorkflows { // Find the workflow in the registry - registeredWorkflow, exists := registry[workflow.name] + registeredWorkflow, exists := dbos.workflowRegistry[workflow.name] if !exists { getLogger().Error("workflow function not found in registry", "workflow_name", workflow.name) continue diff --git a/dbos/recovery.go b/dbos/recovery.go index 3dc4aeb8..07792a09 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -38,7 +38,7 @@ func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]Workf continue } - registeredWorkflow, exists := registry[workflow.Name] + registeredWorkflow, exists := dbos.workflowRegistry[workflow.Name] if !exists { getLogger().Error("Workflow function not found in registry", "workflow_id", workflow.ID, "name", workflow.Name) continue diff --git a/dbos/utils_test.go b/dbos/utils_test.go index d2f04731..609cf9ba 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -25,7 +25,7 @@ func getDatabaseURL(t *testing.T) string { } /* Test database setup */ -func setupDBOS(t *testing.T) { +func setupDBOS(t *testing.T) DBOSExecutor { t.Helper() databaseURL := getDatabaseURL(t) @@ -54,7 +54,7 @@ func setupDBOS(t *testing.T) { t.Fatalf("failed to drop test database: %v", err) } - err = Initialize(Config{ + executor, err := Initialize(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) @@ -62,20 +62,25 @@ func setupDBOS(t *testing.T) { t.Fatalf("failed to create DBOS instance: %v", err) } - err = Launch() + err = executor.Launch() if err != nil { t.Fatalf("failed to launch DBOS instance: %v", err) } - if dbos == nil { + if executor == nil { t.Fatal("expected DBOS instance but got nil") } // Register cleanup to run after test completes t.Cleanup(func() { fmt.Println("Cleaning up DBOS instance...") - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } }) + + return executor } /* Event struct provides a simple synchronization primitive that can be used to signal between goroutines. */ diff --git a/dbos/workflow.go b/dbos/workflow.go index 467d6ab8..ec1424aa 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -8,7 +8,6 @@ import ( "math" "reflect" "runtime" - "sync" "time" "github.com/google/uuid" @@ -207,25 +206,46 @@ type workflowRegistryEntry struct { maxRetries int } -var registry = make(map[string]workflowRegistryEntry) -var regMutex sync.RWMutex - // Register adds a workflow function to the registry (thread-safe, only once per name) -func registerWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { - regMutex.Lock() - defer regMutex.Unlock() +func (e *executor) RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { + e.workflowRegMutex.Lock() + defer e.workflowRegMutex.Unlock() - if _, exists := registry[fqn]; exists { + if _, exists := e.workflowRegistry[fqn]; exists { getLogger().Error("workflow function already registered", "fqn", fqn) panic(newConflictingRegistrationError(fqn)) } - registry[fqn] = workflowRegistryEntry{ + e.workflowRegistry[fqn] = workflowRegistryEntry{ wrappedFunction: fn, maxRetries: maxRetries, } } +func (e *executor) RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) { + e.GetWorkflowScheduler().Start() + var entryID cron.EntryID + entryID, err := e.GetWorkflowScheduler().AddFunc(cronSchedule, func() { + // Execute the workflow on the cron schedule once DBOS is launched + if e == nil { + return + } + // Get the scheduled time from the cron entry + entry := e.GetWorkflowScheduler().Entry(entryID) + scheduledTime := entry.Prev + if scheduledTime.IsZero() { + // Use Next if Prev is not set, which will only happen for the first run + scheduledTime = entry.Next + } + wfID := fmt.Sprintf("sched-%s-%s", fqn, scheduledTime) // XXX we can rethink the format + fn(context.Background(), scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) + }) + if err != nil { + panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) + } + getLogger().Info("Registered scheduled workflow", "fqn", fqn, "cron_schedule", cronSchedule) +} + type workflowRegistrationParams struct { cronSchedule string maxRetries int @@ -250,9 +270,10 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } -func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { - if dbos != nil { - getLogger().Warn("WithWorkflow called after DBOS initialization, dynamic registration is not supported") +func WithWorkflow[P any, R any](dbosExecutor DBOSExecutor, fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { + if dbosExecutor == nil { + // TODO: consider panic here + getLogger().Error("Provide a valid DBOS executor instance") return nil } @@ -277,39 +298,9 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrat // Wrap the function in a durable workflow wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx context.Context, workflowInput P, opts ...workflowOption) (WorkflowHandle[R], error) { opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) - return runAsWorkflow(ctx, fn, workflowInput, opts...) + return runAsWorkflow(ctx, dbosExecutor, fn, workflowInput, opts...) }) - // If this is a scheduled workflow, register a cron job - if registrationParams.cronSchedule != "" { - if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { - panic(fmt.Sprintf("scheduled workflow function must accept ScheduledWorkflowInput as input, got %T", p)) - } - if workflowScheduler == nil { - workflowScheduler = cron.New(cron.WithSeconds()) - } - var entryID cron.EntryID - entryID, err := workflowScheduler.AddFunc(registrationParams.cronSchedule, func() { - // Execute the workflow on the cron schedule once DBOS is launched - if dbos == nil { - return - } - // Get the scheduled time from the cron entry - entry := workflowScheduler.Entry(entryID) - scheduledTime := entry.Prev - if scheduledTime.IsZero() { - // Use Next if Prev is not set, which will only happen for the first run - scheduledTime = entry.Next - } - wfID := fmt.Sprintf("sched-%s-%s", fqn, scheduledTime) // XXX we can rethink the format - wrappedFunction(context.Background(), any(scheduledTime).(P), WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) - }) - if err != nil { - panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) - } - getLogger().Info("Registered scheduled workflow", "fqn", fqn, "cron_schedule", registrationParams.cronSchedule) - } - // Register a type-erased version of the durable workflow for recovery typeErasedWrapper := func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) @@ -323,7 +314,15 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrat } return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID()}, nil } - registerWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) + dbosExecutor.RegisterWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) + + // If this is a scheduled workflow, register a cron job + if registrationParams.cronSchedule != "" { + if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { + panic(fmt.Sprintf("scheduled workflow function must accept a time.Time as input, got %T", p)) + } + dbosExecutor.RegisterScheduledWorkflow(fqn, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) + } return wrappedFunction } @@ -386,10 +385,10 @@ func WithWorkflowMaxRetries(maxRetries int) workflowOption { } } -func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { +func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { // Apply options to build params params := workflowParams{ - applicationVersion: dbos.applicationVersion, + applicationVersion: dbosExecutor.GetApplicationVersion(), } for _, opt := range opts { opt(¶ms) From b66a2aa34642b62762db6a175f66e8df8911eb28 Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 29 Jul 2025 08:45:46 -0700 Subject: [PATCH 02/30] update tests --- dbos/admin_server_test.go | 28 +++++-- dbos/dbos.go | 2 - dbos/logger_test.go | 18 +++-- dbos/queues_test.go | 151 ++++++++++++++++++++---------------- dbos/serialization_test.go | 28 +++---- dbos/workflows_test.go | 154 +++++++++++++++++++------------------ 6 files changed, 208 insertions(+), 173 deletions(-) diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 9f971ffa..7ea15750 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -15,23 +15,29 @@ func TestAdminServer(t *testing.T) { t.Run("Admin server is not started by default", func(t *testing.T) { // Ensure clean state - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } - err := Initialize(Config{ + executor, err := Initialize(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) if err != nil { t.Skipf("Failed to initialize DBOS: %v", err) } - err = Launch() + err = executor.Launch() if err != nil { t.Skipf("Failed to initialize DBOS: %v", err) } // Ensure cleanup defer func() { - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } }() // Give time for any startup processes @@ -55,10 +61,13 @@ func TestAdminServer(t *testing.T) { }) t.Run("Admin server endpoints", func(t *testing.T) { - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } // Launch DBOS with admin server once for all endpoint tests - err := Initialize(Config{ + executor, err := Initialize(Config{ DatabaseURL: databaseURL, AppName: "test-app", AdminServer: true, @@ -66,14 +75,17 @@ func TestAdminServer(t *testing.T) { if err != nil { t.Skipf("Failed to initialize DBOS with admin server: %v", err) } - err = Launch() + err = executor.Launch() if err != nil { t.Skipf("Failed to initialize DBOS with admin server: %v", err) } // Ensure cleanup defer func() { - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } }() // Give the server a moment to start diff --git a/dbos/dbos.go b/dbos/dbos.go index acd10f40..bffda31c 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -259,8 +259,6 @@ func (e *executor) Shutdown() { if logger != nil { logger = nil } - // XX now responsiblity of the caller right? - e = nil } func GetBinaryHash() (string, error) { diff --git a/dbos/logger_test.go b/dbos/logger_test.go index 2b7eeebb..1f70db0e 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -11,19 +11,22 @@ func TestLogger(t *testing.T) { databaseURL := getDatabaseURL(t) t.Run("Default logger", func(t *testing.T) { - err := Initialize(Config{ + executor, err := Initialize(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) // Create executor with default logger if err != nil { t.Fatalf("Failed to create executor with default logger: %v", err) } - err = Launch() + err = executor.Launch() if err != nil { t.Fatalf("Failed to launch with default logger: %v", err) } t.Cleanup(func() { - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } }) if logger == nil { @@ -45,7 +48,7 @@ func TestLogger(t *testing.T) { // Add some context to the slog logger slogLogger = slogLogger.With("service", "dbos-test", "environment", "test") - err := Initialize(Config{ + executor, err := Initialize(Config{ DatabaseURL: databaseURL, AppName: "test-app", Logger: slogLogger, @@ -53,12 +56,15 @@ func TestLogger(t *testing.T) { if err != nil { t.Fatalf("Failed to create executor with custom logger: %v", err) } - err = Launch() + err = executor.Launch() if err != nil { t.Fatalf("Failed to launch with custom logger: %v", err) } t.Cleanup(func() { - Shutdown() + if dbos != nil { + dbos.Shutdown() + dbos = nil + } }) if logger == nil { diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 74a11c9c..618ce15f 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -27,21 +27,13 @@ This suite tests */ var ( - queue = NewWorkflowQueue("test-queue") - queueWf = WithWorkflow(queueWorkflow) - queueWfWithChild = WithWorkflow(queueWorkflowWithChild) - queueWfThatEnqueues = WithWorkflow(queueWorkflowThatEnqueues) + queue = NewWorkflowQueue("test-queue") // Variables for successive enqueue test - dlqEnqueueQueue = NewWorkflowQueue("test-successive-enqueue-queue") - dlqStartEvent = NewEvent() - dlqCompleteEvent = NewEvent() - dlqMaxRetries = 10 - enqueueWorkflowDLQ = WithWorkflow(func(ctx context.Context, input string) (string, error) { - dlqStartEvent.Set() - dlqCompleteEvent.Wait() - return input, nil - }, WithMaxRetries(dlqMaxRetries)) + dlqEnqueueQueue = NewWorkflowQueue("test-successive-enqueue-queue") + dlqStartEvent = NewEvent() + dlqCompleteEvent = NewEvent() + dlqMaxRetries = 10 ) func queueWorkflow(ctx context.Context, input string) (string, error) { @@ -56,40 +48,52 @@ func queueStep(ctx context.Context, input string) (string, error) { return input, nil } -func queueWorkflowWithChild(ctx context.Context, input string) (string, error) { - // Start a child workflow - childHandle, err := queueWf(ctx, input+"-child") - if err != nil { - return "", fmt.Errorf("failed to start child workflow: %v", err) - } - - // Get result from child workflow - childResult, err := childHandle.GetResult(ctx) - if err != nil { - return "", fmt.Errorf("failed to get child result: %v", err) - } - return childResult, nil -} +func TestWorkflowQueues(t *testing.T) { + executor := setupDBOS(t) + + // Setup workflows with executor + queueWf := WithWorkflow(executor, queueWorkflow) + + // Create workflow with child that can call the main workflow + queueWfWithChild := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + // Start a child workflow + childHandle, err := queueWf(ctx, input+"-child") + if err != nil { + return "", fmt.Errorf("failed to start child workflow: %v", err) + } -func queueWorkflowThatEnqueues(ctx context.Context, input string) (string, error) { - // Enqueue another workflow to the same queue - enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.name)) - if err != nil { - return "", fmt.Errorf("failed to enqueue workflow: %v", err) - } + // Get result from child workflow + childResult, err := childHandle.GetResult(ctx) + if err != nil { + return "", fmt.Errorf("failed to get child result: %v", err) + } - // Get result from the enqueued workflow - enqueuedResult, err := enqueuedHandle.GetResult(ctx) - if err != nil { - return "", fmt.Errorf("failed to get enqueued workflow result: %v", err) - } + return childResult, nil + }) + + // Create workflow that enqueues another workflow + queueWfThatEnqueues := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + // Enqueue another workflow to the same queue + enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.name)) + if err != nil { + return "", fmt.Errorf("failed to enqueue workflow: %v", err) + } - return enqueuedResult, nil -} + // Get result from the enqueued workflow + enqueuedResult, err := enqueuedHandle.GetResult(ctx) + if err != nil { + return "", fmt.Errorf("failed to get enqueued workflow result: %v", err) + } -func TestWorkflowQueues(t *testing.T) { - setupDBOS(t) + return enqueuedResult, nil + }) + + enqueueWorkflowDLQ := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + dlqStartEvent.Set() + dlqCompleteEvent.Wait() + return input, nil + }, WithMaxRetries(dlqMaxRetries)) t.Run("EnqueueWorkflow", func(t *testing.T) { handle, err := queueWf(context.Background(), "test-input", WithQueue(queue.name)) @@ -249,14 +253,23 @@ var ( recoveryStepEvents = make([]*Event, 5) // 5 queued steps recoveryEvent = NewEvent() - recoveryStepWorkflow = WithWorkflow(func(ctx context.Context, i int) (int, error) { +) + +func TestQueueRecovery(t *testing.T) { + executor := setupDBOS(t) + + // Create workflows with executor + var recoveryStepWorkflow func(context.Context, int, ...workflowOption) (WorkflowHandle[int], error) + var recoveryWorkflow func(context.Context, string, ...workflowOption) (WorkflowHandle[[]int], error) + + recoveryStepWorkflow = WithWorkflow(executor, func(ctx context.Context, i int) (int, error) { recoveryStepCounter++ recoveryStepEvents[i].Set() recoveryEvent.Wait() return i, nil }) - recoveryWorkflow = WithWorkflow(func(ctx context.Context, input string) ([]int, error) { + recoveryWorkflow = WithWorkflow(executor, func(ctx context.Context, input string) ([]int, error) { handles := make([]WorkflowHandle[int], 0, 5) // 5 queued steps for i := range 5 { handle, err := recoveryStepWorkflow(ctx, i, WithQueue(recoveryQueue.name)) @@ -276,10 +289,6 @@ var ( } return results, nil }) -) - -func TestQueueRecovery(t *testing.T) { - setupDBOS(t) queuedSteps := 5 @@ -377,7 +386,13 @@ var ( workflowEvent1 = NewEvent() workflowEvent2 = NewEvent() workflowDoneEvent = NewEvent() - globalConcurrencyWorkflow = WithWorkflow(func(ctx context.Context, input string) (string, error) { +) + +func TestGlobalConcurrency(t *testing.T) { + executor := setupDBOS(t) + + // Create workflow with executor + globalConcurrencyWorkflow := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { switch input { case "workflow1": workflowEvent1.Set() @@ -387,10 +402,6 @@ var ( } return input, nil }) -) - -func TestGlobalConcurrency(t *testing.T) { - setupDBOS(t) // Enqueue two workflows handle1, err := globalConcurrencyWorkflow(context.Background(), "workflow1", WithQueue(globalConcurrencyQueue.name)) @@ -459,16 +470,18 @@ var ( NewEvent(), NewEvent(), } - blockingWf = WithWorkflow(func(ctx context.Context, i int) (int, error) { +) + +func TestWorkerConcurrency(t *testing.T) { + executor := setupDBOS(t) + + // Create workflow with executor + blockingWf := WithWorkflow(executor, func(ctx context.Context, i int) (int, error) { // Simulate a blocking operation startEvents[i].Set() completeEvents[i].Wait() return i, nil }) -) - -func TestWorkerConcurrency(t *testing.T) { - setupDBOS(t) // First enqueue four blocking workflows handle1, err := blockingWf(context.Background(), 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) @@ -606,20 +619,22 @@ var ( workerConcurrencyRecoveryStartEvent2 = NewEvent() workerConcurrencyRecoveryCompleteEvent1 = NewEvent() workerConcurrencyRecoveryCompleteEvent2 = NewEvent() - workerConcurrencyRecoveryBlockingWf1 = WithWorkflow(func(ctx context.Context, input string) (string, error) { +) + +func TestWorkerConcurrencyXRecovery(t *testing.T) { + executor := setupDBOS(t) + + // Create workflows with executor + workerConcurrencyRecoveryBlockingWf1 := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { workerConcurrencyRecoveryStartEvent1.Set() workerConcurrencyRecoveryCompleteEvent1.Wait() return input, nil }) - workerConcurrencyRecoveryBlockingWf2 = WithWorkflow(func(ctx context.Context, input string) (string, error) { + workerConcurrencyRecoveryBlockingWf2 := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { workerConcurrencyRecoveryStartEvent2.Set() workerConcurrencyRecoveryCompleteEvent2.Wait() return input, nil }) -) - -func TestWorkerConcurrencyXRecovery(t *testing.T) { - setupDBOS(t) // Enqueue two workflows on a queue with worker concurrency = 1 handle1, err := workerConcurrencyRecoveryBlockingWf1(context.Background(), "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) @@ -715,8 +730,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } var ( - rateLimiterQueue = NewWorkflowQueue("test-rate-limiter-queue", WithRateLimiter(&RateLimiter{Limit: 5, Period: 1.8})) - rateLimiterWorkflow = WithWorkflow(rateLimiterTestWorkflow) + rateLimiterQueue = NewWorkflowQueue("test-rate-limiter-queue", WithRateLimiter(&RateLimiter{Limit: 5, Period: 1.8})) ) func rateLimiterTestWorkflow(ctx context.Context, _ string) (time.Time, error) { @@ -724,7 +738,10 @@ func rateLimiterTestWorkflow(ctx context.Context, _ string) (time.Time, error) { } func TestQueueRateLimiter(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Create workflow with executor + rateLimiterWorkflow := WithWorkflow(executor, rateLimiterTestWorkflow) limit := 5 period := 1.8 diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 7e6acc39..8450b0ae 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -17,11 +17,6 @@ import ( [x] Set/get event with user defined types */ -var ( - builtinWf = WithWorkflow(encodingWorkflowBuiltinTypes) - structWf = WithWorkflow(encodingWorkflowStruct) -) - // Builtin types func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { return input, errors.New("step error") @@ -68,7 +63,11 @@ func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputS } func TestWorkflowEncoding(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Create workflows with executor + builtinWf := WithWorkflow(executor, encodingWorkflowBuiltinTypes) + structWf := WithWorkflow(executor, encodingWorkflowStruct) t.Run("BuiltinTypes", func(t *testing.T) { // Test a workflow that uses a built-in type (string) @@ -321,10 +320,11 @@ func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, return "user-defined-event-set", nil } -var setEventUserDefinedTypeWf = WithWorkflow(setEventUserDefinedTypeWorkflow) - func TestSetEventSerialize(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Create workflow with executor + setEventUserDefinedTypeWf := WithWorkflow(executor, setEventUserDefinedTypeWorkflow) t.Run("SetEventUserDefinedType", func(t *testing.T) { // Start a workflow that sets an event with a user-defined type @@ -374,7 +374,6 @@ func TestSetEventSerialize(t *testing.T) { }) } - func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (string, error) { // Create an instance of our user-defined type inside the workflow sendData := UserDefinedEventData{ @@ -411,11 +410,12 @@ func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefined return result, err } -var sendUserDefinedTypeWf = WithWorkflow(sendUserDefinedTypeWorkflow) -var recvUserDefinedTypeWf = WithWorkflow(recvUserDefinedTypeWorkflow) - func TestSendSerialize(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Create workflows with executor + sendUserDefinedTypeWf := WithWorkflow(executor, sendUserDefinedTypeWorkflow) + recvUserDefinedTypeWf := WithWorkflow(executor, recvUserDefinedTypeWorkflow) t.Run("SendUserDefinedType", func(t *testing.T) { // Start a receiver workflow first diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 7fad1361..0789df9f 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -24,29 +24,6 @@ import ( // Global counter for idempotency testing var idempotencyCounter int64 -var ( - simpleWf = WithWorkflow(simpleWorkflow) - simpleWfError = WithWorkflow(simpleWorkflowError) - simpleWfWithStep = WithWorkflow(simpleWorkflowWithStep) - simpleWfWithStepError = WithWorkflow(simpleWorkflowWithStepError) - // struct methods - s = workflowStruct{} - simpleWfStruct = WithWorkflow(s.simpleWorkflow) - simpleWfValue = WithWorkflow(s.simpleWorkflowValue) - // interface method workflow - workflowIface TestWorkflowInterface = &workflowImplementation{ - field: "example", - } - simpleWfIface = WithWorkflow(workflowIface.Execute) - // Generic workflow - wfInt = WithWorkflow(Identity[string]) // FIXME make this an int eventually - // Closure with captured state - prefix = "hello-" - wfClose = WithWorkflow(func(ctx context.Context, in string) (string, error) { - return prefix + in, nil - }) -) - func simpleWorkflow(ctxt context.Context, input string) (string, error) { return input, nil } @@ -108,14 +85,34 @@ func Identity[T any](ctx context.Context, in T) (T, error) { return in, nil } -var ( - anonymousWf = WithWorkflow(func(ctx context.Context, in string) (string, error) { +func TestWorkflowsWrapping(t *testing.T) { + executor := setupDBOS(t) + + // Setup workflows with executor + simpleWf := WithWorkflow(executor, simpleWorkflow) + simpleWfError := WithWorkflow(executor, simpleWorkflowError) + simpleWfWithStep := WithWorkflow(executor, simpleWorkflowWithStep) + simpleWfWithStepError := WithWorkflow(executor, simpleWorkflowWithStepError) + // struct methods + s := workflowStruct{} + simpleWfStruct := WithWorkflow(executor, s.simpleWorkflow) + simpleWfValue := WithWorkflow(executor, s.simpleWorkflowValue) + // interface method workflow + workflowIface := TestWorkflowInterface(&workflowImplementation{ + field: "example", + }) + simpleWfIface := WithWorkflow(executor, workflowIface.Execute) + // Generic workflow + wfInt := WithWorkflow(executor, Identity[string]) // FIXME make this an int eventually + // Closure with captured state + prefix := "hello-" + wfClose := WithWorkflow(executor, func(ctx context.Context, in string) (string, error) { + return prefix + in, nil + }) + // Anonymous workflow + anonymousWf := WithWorkflow(executor, func(ctx context.Context, in string) (string, error) { return "anonymous-" + in, nil }) -) - -func TestWorkflowsWrapping(t *testing.T) { - setupDBOS(t) type testCase struct { name string @@ -324,13 +321,12 @@ func stepRetryWorkflow(ctx context.Context, input string) (string, error) { WithMaxInterval(10*time.Millisecond)) } -var ( - stepWithinAStepWf = WithWorkflow(stepWithinAStepWorkflow) - stepRetryWf = WithWorkflow(stepRetryWorkflow) -) - func TestSteps(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Create workflows with executor + stepWithinAStepWf := WithWorkflow(executor, stepWithinAStepWorkflow) + stepRetryWf := WithWorkflow(executor, stepRetryWorkflow) t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { ctx := context.Background() @@ -452,8 +448,14 @@ func TestSteps(t *testing.T) { }) } -var ( - childWf = WithWorkflow(func(ctx context.Context, i int) (string, error) { +// Functions that create child workflows - moved to test function where executor is available + +// TODO Check timeouts behaviors for parents and children (e.g. awaited cancelled, etc) +func TestChildWorkflow(t *testing.T) { + executor := setupDBOS(t) + + // Create child workflows with executor + childWf := WithWorkflow(executor, func(ctx context.Context, i int) (string, error) { workflowID, err := GetWorkflowID(ctx) if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) @@ -465,7 +467,8 @@ var ( // XXX right now the steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow return RunAsStep(ctx, simpleStep, "") }) - parentWf = WithWorkflow(func(ctx context.Context, i int) (string, error) { + + parentWf := WithWorkflow(executor, func(ctx context.Context, i int) (string, error) { workflowID, err := GetWorkflowID(ctx) if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) @@ -490,7 +493,8 @@ var ( } return childHandle.GetResult(ctx) }) - grandParentWf = WithWorkflow(func(ctx context.Context, _ string) (string, error) { + + grandParentWf := WithWorkflow(executor, func(ctx context.Context, _ string) (string, error) { for i := range 3 { workflowID, err := GetWorkflowID(ctx) if err != nil { @@ -529,11 +533,6 @@ var ( return "", nil }) -) - -// TODO Check timeouts behaviors for parents and children (e.g. awaited cancelled, etc) -func TestChildWorkflow(t *testing.T) { - setupDBOS(t) t.Run("ChildWorkflowIDPattern", func(t *testing.T) { h, err := grandParentWf(context.Background(), "") @@ -547,10 +546,7 @@ func TestChildWorkflow(t *testing.T) { }) } -var ( - idempotencyWf = WithWorkflow(idempotencyWorkflow) - idempotencyWfWithStep = WithWorkflow(idempotencyWorkflowWithStep) -) +// Idempotency workflows moved to test functions func idempotencyWorkflow(ctx context.Context, input string) (string, error) { incrementCounter(ctx, 1) @@ -574,7 +570,8 @@ func idempotencyWorkflowWithStep(ctx context.Context, input string) (int64, erro } func TestWorkflowIdempotency(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + idempotencyWf := WithWorkflow(executor, idempotencyWorkflow) t.Run("WorkflowExecutedOnlyOnce", func(t *testing.T) { idempotencyCounter = 0 @@ -622,7 +619,8 @@ func TestWorkflowIdempotency(t *testing.T) { } func TestWorkflowRecovery(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + idempotencyWfWithStep := WithWorkflow(executor, idempotencyWorkflowWithStep) t.Run("RecoveryResumeWhereItLeftOff", func(t *testing.T) { // Reset the global counter idempotencyCounter = 0 @@ -700,8 +698,6 @@ func TestWorkflowRecovery(t *testing.T) { var ( maxRecoveryAttempts = 20 - deadLetterQueueWf = WithWorkflow(deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) - infiniteDeadLetterQueueWf = WithWorkflow(infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries deadLetterQueueStartEvent *Event deadLetterQueueEvent *Event recoveryCount int64 @@ -721,7 +717,9 @@ func infiniteDeadLetterQueueWorkflow(ctx context.Context, input string) (int, er return 0, nil } func TestWorkflowDeadLetterQueue(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + deadLetterQueueWf := WithWorkflow(executor, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) + infiniteDeadLetterQueueWf := WithWorkflow(executor, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries t.Run("DeadLetterQueueBehavior", func(t *testing.T) { deadLetterQueueEvent = NewEvent() @@ -890,7 +888,11 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { var ( counter = 0 counter1Ch = make(chan time.Time, 100) - _ = WithWorkflow(func(ctx context.Context, scheduledTime time.Time) (string, error) { +) + +func TestScheduledWorkflows(t *testing.T) { + executor := setupDBOS(t) + _ = WithWorkflow(executor, func(ctx context.Context, scheduledTime time.Time) (string, error) { startTime := time.Now() counter++ if counter == 10 { @@ -902,10 +904,6 @@ var ( } return fmt.Sprintf("Scheduled workflow scheduled at time %v and executed at time %v", scheduledTime, startTime), nil }, WithSchedule("* * * * * *")) // Every second -) - -func TestScheduledWorkflows(t *testing.T) { - setupDBOS(t) // Helper function to collect execution times collectExecutionTimes := func(ch chan time.Time, target int, timeout time.Duration) ([]time.Time, error) { @@ -956,7 +954,7 @@ func TestScheduledWorkflows(t *testing.T) { // Stop the workflowScheduler and check if it stops executing currentCounter := counter - workflowScheduler.Stop() + executor.GetWorkflowScheduler().Stop() time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions if counter >= currentCounter+2 { t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2) @@ -965,17 +963,9 @@ func TestScheduledWorkflows(t *testing.T) { } var ( - sendWf = WithWorkflow(sendWorkflow) - receiveWf = WithWorkflow(receiveWorkflow) - receiveWfCoordinated = WithWorkflow(receiveWorkflowCoordinated) - sendStructWf = WithWorkflow(sendStructWorkflow) - receiveStructWf = WithWorkflow(receiveStructWorkflow) - sendIdempotencyWf = WithWorkflow(sendIdempotencyWorkflow) sendIdempotencyEvent = NewEvent() - recvIdempotencyWf = WithWorkflow(receiveIdempotencyWorkflow) receiveIdempotencyStartEvent = NewEvent() receiveIdempotencyStopEvent = NewEvent() - sendWithinStepWf = WithWorkflow(workflowThatCallsSendInStep) numConcurrentRecvWfs = 5 concurrentRecvReadyEvents = make([]*Event, numConcurrentRecvWfs) concurrentRecvStartEvent = NewEvent() @@ -1087,7 +1077,17 @@ type sendRecvType struct { } func TestSendRecv(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Register all send/recv workflows with executor + sendWf := WithWorkflow(executor, sendWorkflow) + receiveWf := WithWorkflow(executor, receiveWorkflow) + receiveWfCoordinated := WithWorkflow(executor, receiveWorkflowCoordinated) + sendStructWf := WithWorkflow(executor, sendStructWorkflow) + receiveStructWf := WithWorkflow(executor, receiveStructWorkflow) + sendIdempotencyWf := WithWorkflow(executor, sendIdempotencyWorkflow) + recvIdempotencyWf := WithWorkflow(executor, receiveIdempotencyWorkflow) + sendWithinStepWf := WithWorkflow(executor, workflowThatCallsSendInStep) t.Run("SendRecvSuccess", func(t *testing.T) { // Start the receive workflow @@ -1455,11 +1455,6 @@ func TestSendRecv(t *testing.T) { } var ( - setEventWf = WithWorkflow(setEventWorkflow) - getEventWf = WithWorkflow(getEventWorkflow) - setTwoEventsWf = WithWorkflow(setTwoEventsWorkflow) - setEventIdempotencyWf = WithWorkflow(setEventIdempotencyWorkflow) - getEventIdempotencyWf = WithWorkflow(getEventIdempotencyWorkflow) setEventIdempotencyEvent = NewEvent() getEventStartIdempotencyEvent = NewEvent() getEventStopIdempotencyEvent = NewEvent() @@ -1534,7 +1529,14 @@ func getEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInpu } func TestSetGetEvent(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Register all set/get event workflows with executor + setEventWf := WithWorkflow(executor, setEventWorkflow) + getEventWf := WithWorkflow(executor, getEventWorkflow) + setTwoEventsWf := WithWorkflow(executor, setTwoEventsWorkflow) + setEventIdempotencyWf := WithWorkflow(executor, setEventIdempotencyWorkflow) + getEventIdempotencyWf := WithWorkflow(executor, getEventIdempotencyWorkflow) t.Run("SetGetEventFromWorkflow", func(t *testing.T) { // Clear the signal event before starting @@ -1815,7 +1817,6 @@ func TestSetGetEvent(t *testing.T) { } var ( - sleepRecoveryWf = WithWorkflow(sleepRecoveryWorkflow) sleepStartEvent *Event sleepStopEvent *Event ) @@ -1832,7 +1833,8 @@ func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Du } func TestSleep(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + sleepRecoveryWf := WithWorkflow(executor, sleepRecoveryWorkflow) t.Run("SleepDurableRecovery", func(t *testing.T) { sleepStartEvent = NewEvent() From e34671262aa63152481829e47de264a1a3dbdc9f Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 29 Jul 2025 20:33:09 -0700 Subject: [PATCH 03/30] wip --- dbos/admin_server.go | 4 +- dbos/admin_server_test.go | 39 +++--- dbos/dbos.go | 132 ++++++++++++++---- dbos/dbos_test.go | 4 +- dbos/logger_test.go | 18 ++- dbos/queue.go | 17 ++- dbos/queues_test.go | 54 ++++---- dbos/recovery.go | 17 +-- dbos/serialization_test.go | 26 ++-- dbos/system_database.go | 66 ++++----- dbos/utils_test.go | 42 +++--- dbos/workflow.go | 271 ++++++++++++++++++++++--------------- dbos/workflows_test.go | 144 ++++++++++---------- 13 files changed, 478 insertions(+), 356 deletions(-) diff --git a/dbos/admin_server.go b/dbos/admin_server.go index b942c13a..364f3960 100644 --- a/dbos/admin_server.go +++ b/dbos/admin_server.go @@ -25,7 +25,7 @@ type queueMetadata struct { RateLimit *RateLimiter `json:"rateLimit,omitempty"` } -func newAdminServer(port int) *adminServer { +func newAdminServer(dbosCtx *dbosContext, port int) *adminServer { mux := http.NewServeMux() // Health endpoint @@ -50,7 +50,7 @@ func newAdminServer(port int) *adminServer { getLogger().Info("Recovering workflows for executors", "executors", executorIDs) - handles, err := recoverPendingWorkflows(r.Context(), executorIDs) + handles, err := recoverPendingWorkflows(dbosCtx, executorIDs) if err != nil { getLogger().Error("Error recovering workflows", "error", err) http.Error(w, fmt.Sprintf("Recovery failed: %v", err), http.StatusInternalServerError) diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 7ea15750..1d9fb117 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -14,13 +14,8 @@ func TestAdminServer(t *testing.T) { databaseURL := getDatabaseURL(t) t.Run("Admin server is not started by default", func(t *testing.T) { - // Ensure clean state - if dbos != nil { - dbos.Shutdown() - dbos = nil - } - executor, err := Initialize(Config{ + executor, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) @@ -34,10 +29,9 @@ func TestAdminServer(t *testing.T) { // Ensure cleanup defer func() { - if dbos != nil { - dbos.Shutdown() - dbos = nil - } + if executor != nil { + executor.Shutdown() + } }() // Give time for any startup processes @@ -51,23 +45,22 @@ func TestAdminServer(t *testing.T) { } // Verify the DBOS executor doesn't have an admin server instance - if dbos == nil { + if executor == nil { t.Fatal("Expected DBOS instance to be created") } - if dbos.adminServer != nil { + exec := executor.(*dbosContext) + if exec.adminServer != nil { t.Error("Expected admin server to be nil when not configured") } }) t.Run("Admin server endpoints", func(t *testing.T) { - if dbos != nil { - dbos.Shutdown() - dbos = nil - } + // Clean up any existing instance + // (This will be handled by the individual executor cleanup) // Launch DBOS with admin server once for all endpoint tests - executor, err := Initialize(Config{ + executor, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", AdminServer: true, @@ -82,21 +75,21 @@ func TestAdminServer(t *testing.T) { // Ensure cleanup defer func() { - if dbos != nil { - dbos.Shutdown() - dbos = nil - } + if executor != nil { + executor.Shutdown() + } }() // Give the server a moment to start time.Sleep(100 * time.Millisecond) // Verify the DBOS executor has an admin server instance - if dbos == nil { + if executor == nil { t.Fatal("Expected DBOS instance to be created") } - if dbos.adminServer == nil { + exec := executor.(*dbosContext) + if exec.adminServer == nil { t.Fatal("Expected admin server to be created in DBOS instance") } diff --git a/dbos/dbos.go b/dbos/dbos.go index bffda31c..2ffb8fa4 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -22,7 +22,7 @@ var ( var logger *slog.Logger // Global because accessed everywhere inside the library func getLogger() *slog.Logger { - if dbos == nil || logger == nil { + if logger == nil { return slog.New(slog.NewTextHandler(os.Stderr, nil)) } return logger @@ -60,60 +60,139 @@ func processConfig(inputConfig *Config) (*Config, error) { return dbosConfig, nil } -type DBOSExecutor interface { +type DBOSContext interface { + context.Context // Standard Go context behavior + + // Context Lifecycle Launch() error Shutdown() + WithValue(key, val any) DBOSContext + // Workflow registration RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) + // Workflow operations + RunAsStep(fn TypeErasedStepFunc, input any, stepName string, opts ...stepOption) (any, error) + Send(input WorkflowSendInputInternal) error + Recv(input WorkflowRecvInput) (any, error) + SetEvent(input WorkflowSetEventInputInternal) error + GetEvent(input WorkflowGetEventInput) (any, error) + Sleep(duration time.Duration) (time.Duration, error) + + // Workflow management + RetrieveWorkflow(workflowIDs []string) ([]WorkflowStatus, error) + CheckChildWorkflow(parentWorkflowID string, stepCounter int) (*string, error) + InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*insertWorkflowResult, error) + RecordChildWorkflow(input recordChildWorkflowDBInput) error + UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error + + // Context operations + GetWorkflowID() (string, error) + + // Accessors GetWorkflowScheduler() *cron.Cron GetApplicationVersion() string + GetSystemDB() SystemDatabase + GetContext() context.Context + GetExecutorID() string + GetApplicationID() string + GetWorkflowWg() *sync.WaitGroup } -var dbos *executor // DBOS singleton instance - -type executor struct { +type dbosContext struct { + ctx context.Context // Embedded context for standard behavior systemDB SystemDatabase adminServer *adminServer config *Config + // Queue runner context and cancel function queueRunnerCtx context.Context queueRunnerCancelFunc context.CancelFunc queueRunnerDone chan struct{} + // Application metadata applicationVersion string applicationID string executorID string + // Wait group for workflow goroutines workflowsWg *sync.WaitGroup + // Workflow registry workflowRegistry map[string]workflowRegistryEntry - workflowRegMutex sync.RWMutex + workflowRegMutex *sync.RWMutex + // Workflow scheduler workflowScheduler *cron.Cron } -func (e *executor) GetWorkflowScheduler() *cron.Cron { +// Implement contex.Context interface methods +func (e *dbosContext) Deadline() (deadline time.Time, ok bool) { + return e.ctx.Deadline() +} + +func (e *dbosContext) Done() <-chan struct{} { + return e.ctx.Done() +} + +func (e *dbosContext) Err() error { + return e.ctx.Err() +} + +func (e *dbosContext) Value(key any) any { + return e.ctx.Value(key) +} + +// Create a new context +// This is intended for workflow contexts and step contexts +// Hence we only set the relevant fields +func (e *dbosContext) WithValue(key, val any) DBOSContext { + return &dbosContext{ + ctx: context.WithValue(e.ctx, key, val), + systemDB: e.systemDB, + applicationVersion: e.applicationVersion, + executorID: e.executorID, + applicationID: e.applicationID, + workflowsWg: e.workflowsWg, + } +} + +func (e *dbosContext) GetContext() context.Context { + return e.ctx +} + +func (e *dbosContext) GetWorkflowScheduler() *cron.Cron { if e.workflowScheduler == nil { e.workflowScheduler = cron.New(cron.WithSeconds()) } return e.workflowScheduler } -func (e *executor) GetApplicationVersion() string { +func (e *dbosContext) GetApplicationVersion() string { return e.applicationVersion } -// TODO: use a normal builder pattern name (NewDBOSExecutor) -func Initialize(inputConfig Config) (DBOSExecutor, error) { - if dbos != nil { - fmt.Println("warning: DBOS instance already initialized, skipping re-initialization") - return nil, newInitializationError("DBOS already initialized") - } +func (e *dbosContext) GetSystemDB() SystemDatabase { + return e.systemDB +} + +func (e *dbosContext) GetExecutorID() string { + return e.executorID +} - initExecutor := &executor{ +func (e *dbosContext) GetApplicationID() string { + return e.applicationID +} + +func (e *dbosContext) GetWorkflowWg() *sync.WaitGroup { + return e.workflowsWg +} + +func NewDBOSContext(inputConfig Config) (DBOSContext, error) { + initExecutor := &dbosContext{ workflowsWg: &sync.WaitGroup{}, + ctx: context.Background(), } // Load & process the configuration @@ -156,29 +235,27 @@ func Initialize(inputConfig Config) (DBOSExecutor, error) { // Initialize the workflow registry initExecutor.workflowRegistry = make(map[string]workflowRegistryEntry) - // Set the global dbos instance - dbos = initExecutor - return initExecutor, nil } -func (e *executor) Launch() error { +func (e *dbosContext) Launch() error { // Start the system database e.systemDB.Launch(context.Background()) // Start the admin server if configured if e.config.AdminServer { - adminServer := newAdminServer(_DEFAULT_ADMIN_SERVER_PORT) + adminServer := newAdminServer(e, _DEFAULT_ADMIN_SERVER_PORT) err := adminServer.Start() if err != nil { logger.Error("Failed to start admin server", "error", err) return newInitializationError(fmt.Sprintf("failed to start admin server: %v", err)) } logger.Info("Admin server started", "port", _DEFAULT_ADMIN_SERVER_PORT) - dbos.adminServer = adminServer + e.adminServer = adminServer } // Create context with cancel function for queue runner + // XXX this can now be a cancel function on the executor itself? ctx, cancel := context.WithCancel(context.Background()) e.queueRunnerCtx = ctx e.queueRunnerCancelFunc = cancel @@ -187,7 +264,7 @@ func (e *executor) Launch() error { // Start the queue runner in a goroutine go func() { defer close(e.queueRunnerDone) - queueRunner(ctx) + queueRunner(e) }() logger.Info("Queue runner started") @@ -198,7 +275,7 @@ func (e *executor) Launch() error { } // Run a round of recovery on the local executor - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{e.executorID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it? + recoveryHandles, err := recoverPendingWorkflows(e, []string{e.executorID}) if err != nil { return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err)) } @@ -210,13 +287,8 @@ func (e *executor) Launch() error { return nil } -func (e *executor) Shutdown() { - if e == nil { - fmt.Println("DBOS instance is nil, cannot shutdown") - return - } - - // XXX is there a way to ensure all workflows goroutine are done before closing? +func (e *dbosContext) Shutdown() { + // Wait for all workflows to finish e.workflowsWg.Wait() // Cancel the context to stop the queue runner diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 4e095c6c..370c1c51 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -12,7 +12,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { DatabaseURL: databaseURL, } - _, err := Initialize(config) + _, err := NewDBOSContext(config) if err == nil { t.Fatal("expected error when app name is missing, but got none") } @@ -37,7 +37,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { AppName: "test-app", } - _, err := Initialize(config) + _, err := NewDBOSContext(config) if err == nil { t.Fatal("expected error when database URL is missing, but got none") } diff --git a/dbos/logger_test.go b/dbos/logger_test.go index 1f70db0e..d33b8efb 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -11,7 +11,7 @@ func TestLogger(t *testing.T) { databaseURL := getDatabaseURL(t) t.Run("Default logger", func(t *testing.T) { - executor, err := Initialize(Config{ + executor, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) // Create executor with default logger @@ -23,10 +23,9 @@ func TestLogger(t *testing.T) { t.Fatalf("Failed to launch with default logger: %v", err) } t.Cleanup(func() { - if dbos != nil { - dbos.Shutdown() - dbos = nil - } + if executor != nil { + executor.Shutdown() + } }) if logger == nil { @@ -48,7 +47,7 @@ func TestLogger(t *testing.T) { // Add some context to the slog logger slogLogger = slogLogger.With("service", "dbos-test", "environment", "test") - executor, err := Initialize(Config{ + executor, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", Logger: slogLogger, @@ -61,10 +60,9 @@ func TestLogger(t *testing.T) { t.Fatalf("Failed to launch with custom logger: %v", err) } t.Cleanup(func() { - if dbos != nil { - dbos.Shutdown() - dbos = nil - } + if executor != nil { + executor.Shutdown() + } }) if logger == nil { diff --git a/dbos/queue.go b/dbos/queue.go index ecb48ce8..43a39f32 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -73,10 +73,7 @@ func WithMaxTasksPerIteration(maxTasks int) queueOption { // NewWorkflowQueue creates a new workflow queue with optional configuration func NewWorkflowQueue(name string, options ...queueOption) WorkflowQueue { - if dbos != nil { - getLogger().Warn("NewWorkflowQueue called after DBOS initialization, dynamic registration is not supported") - return WorkflowQueue{} - } + // TODO: Add runtime check for post-initialization registration if needed if _, exists := workflowQueueRegistry[name]; exists { panic(newConflictingRegistrationError(name)) } @@ -102,7 +99,7 @@ func NewWorkflowQueue(name string, options ...queueOption) WorkflowQueue { return q } -func queueRunner(ctx context.Context) { +func queueRunner(executor *dbosContext) { const ( baseInterval = 1.0 // Base interval in seconds minInterval = 1.0 // Minimum polling interval in seconds @@ -122,7 +119,7 @@ func queueRunner(ctx context.Context) { for queueName, queue := range workflowQueueRegistry { getLogger().Debug("Processing queue", "queue_name", queueName) // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := dbos.systemDB.DequeueWorkflows(ctx, queue) + dequeuedWorkflows, err := executor.systemDB.DequeueWorkflows(executor.GetContext(), queue, executor.executorID, executor.applicationVersion) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { @@ -143,7 +140,7 @@ func queueRunner(ctx context.Context) { } for _, workflow := range dequeuedWorkflows { // Find the workflow in the registry - registeredWorkflow, exists := dbos.workflowRegistry[workflow.name] + registeredWorkflow, exists := executor.workflowRegistry[workflow.name] if !exists { getLogger().Error("workflow function not found in registry", "workflow_name", workflow.name) continue @@ -165,7 +162,9 @@ func queueRunner(ctx context.Context) { } } - _, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id)) + // Create a workflow context from the executor context + workflowCtx := executor.WithValue(context.Background(), nil) + _, err := registeredWorkflow.wrappedFunction(workflowCtx, input, WithWorkflowID(workflow.id)) if err != nil { getLogger().Error("Error running queued workflow", "error", err) } @@ -187,7 +186,7 @@ func queueRunner(ctx context.Context) { // Sleep with jittered interval, but allow early exit on context cancellation select { - case <-ctx.Done(): + case <-executor.GetContext().Done(): getLogger().Info("Queue runner stopping due to context cancellation") return case <-time.After(sleepDuration): diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 618ce15f..d77a8c19 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -36,7 +36,7 @@ var ( dlqMaxRetries = 10 ) -func queueWorkflow(ctx context.Context, input string) (string, error) { +func queueWorkflow(ctx DBOSContext, input string) (string, error) { step1, err := RunAsStep(ctx, queueStep, input) if err != nil { return "", fmt.Errorf("failed to run step: %v", err) @@ -48,15 +48,14 @@ func queueStep(ctx context.Context, input string) (string, error) { return input, nil } - func TestWorkflowQueues(t *testing.T) { executor := setupDBOS(t) - + // Setup workflows with executor - queueWf := WithWorkflow(executor, queueWorkflow) - + queueWf := RegisterWorkflow(executor, queueWorkflow) + // Create workflow with child that can call the main workflow - queueWfWithChild := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + queueWfWithChild := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { // Start a child workflow childHandle, err := queueWf(ctx, input+"-child") if err != nil { @@ -71,9 +70,9 @@ func TestWorkflowQueues(t *testing.T) { return childResult, nil }) - + // Create workflow that enqueues another workflow - queueWfThatEnqueues := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + queueWfThatEnqueues := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { // Enqueue another workflow to the same queue enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.name)) if err != nil { @@ -88,8 +87,8 @@ func TestWorkflowQueues(t *testing.T) { return enqueuedResult, nil }) - - enqueueWorkflowDLQ := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + + enqueueWorkflowDLQ := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { dlqStartEvent.Set() dlqCompleteEvent.Wait() return input, nil @@ -252,24 +251,23 @@ var ( recoveryStepCounter = 0 recoveryStepEvents = make([]*Event, 5) // 5 queued steps recoveryEvent = NewEvent() - ) func TestQueueRecovery(t *testing.T) { executor := setupDBOS(t) - + // Create workflows with executor var recoveryStepWorkflow func(context.Context, int, ...workflowOption) (WorkflowHandle[int], error) var recoveryWorkflow func(context.Context, string, ...workflowOption) (WorkflowHandle[[]int], error) - - recoveryStepWorkflow = WithWorkflow(executor, func(ctx context.Context, i int) (int, error) { + + recoveryStepWorkflow = RegisterWorkflow(executor, func(ctx context.Context, i int) (int, error) { recoveryStepCounter++ recoveryStepEvents[i].Set() recoveryEvent.Wait() return i, nil }) - recoveryWorkflow = WithWorkflow(executor, func(ctx context.Context, input string) ([]int, error) { + recoveryWorkflow = RegisterWorkflow(executor, func(ctx context.Context, input string) ([]int, error) { handles := make([]WorkflowHandle[int], 0, 5) // 5 queued steps for i := range 5 { handle, err := recoveryStepWorkflow(ctx, i, WithQueue(recoveryQueue.name)) @@ -382,17 +380,17 @@ func TestQueueRecovery(t *testing.T) { } var ( - globalConcurrencyQueue = NewWorkflowQueue("test-global-concurrency-queue", WithGlobalConcurrency(1)) - workflowEvent1 = NewEvent() - workflowEvent2 = NewEvent() - workflowDoneEvent = NewEvent() + globalConcurrencyQueue = NewWorkflowQueue("test-global-concurrency-queue", WithGlobalConcurrency(1)) + workflowEvent1 = NewEvent() + workflowEvent2 = NewEvent() + workflowDoneEvent = NewEvent() ) func TestGlobalConcurrency(t *testing.T) { executor := setupDBOS(t) - + // Create workflow with executor - globalConcurrencyWorkflow := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + globalConcurrencyWorkflow := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { switch input { case "workflow1": workflowEvent1.Set() @@ -474,9 +472,9 @@ var ( func TestWorkerConcurrency(t *testing.T) { executor := setupDBOS(t) - + // Create workflow with executor - blockingWf := WithWorkflow(executor, func(ctx context.Context, i int) (int, error) { + blockingWf := RegisterWorkflow(executor, func(ctx context.Context, i int) (int, error) { // Simulate a blocking operation startEvents[i].Set() completeEvents[i].Wait() @@ -623,14 +621,14 @@ var ( func TestWorkerConcurrencyXRecovery(t *testing.T) { executor := setupDBOS(t) - + // Create workflows with executor - workerConcurrencyRecoveryBlockingWf1 := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + workerConcurrencyRecoveryBlockingWf1 := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { workerConcurrencyRecoveryStartEvent1.Set() workerConcurrencyRecoveryCompleteEvent1.Wait() return input, nil }) - workerConcurrencyRecoveryBlockingWf2 := WithWorkflow(executor, func(ctx context.Context, input string) (string, error) { + workerConcurrencyRecoveryBlockingWf2 := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { workerConcurrencyRecoveryStartEvent2.Set() workerConcurrencyRecoveryCompleteEvent2.Wait() return input, nil @@ -739,9 +737,9 @@ func rateLimiterTestWorkflow(ctx context.Context, _ string) (time.Time, error) { func TestQueueRateLimiter(t *testing.T) { executor := setupDBOS(t) - + // Create workflow with executor - rateLimiterWorkflow := WithWorkflow(executor, rateLimiterTestWorkflow) + rateLimiterWorkflow := RegisterWorkflow(executor, rateLimiterTestWorkflow) limit := 5 period := 1.8 diff --git a/dbos/recovery.go b/dbos/recovery.go index 07792a09..a529311c 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -1,17 +1,16 @@ package dbos import ( - "context" "strings" ) -func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]WorkflowHandle[any], error) { +func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors - pendingWorkflows, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + pendingWorkflows, err := dbosCtx.systemDB.ListWorkflows(dbosCtx.GetContext(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusPending}, executorIDs: executorIDs, - applicationVersion: dbos.applicationVersion, + applicationVersion: dbosCtx.applicationVersion, }) if err != nil { return nil, err @@ -27,18 +26,18 @@ func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]Workf // fmt.Println("Recovering workflow:", workflow.ID, "Name:", workflow.Name, "Input:", workflow.Input, "QueueName:", workflow.QueueName) if workflow.QueueName != "" { - cleared, err := dbos.systemDB.ClearQueueAssignment(ctx, workflow.ID) + cleared, err := dbosCtx.systemDB.ClearQueueAssignment(dbosCtx.GetContext(), workflow.ID) if err != nil { getLogger().Error("Error clearing queue assignment for workflow", "workflow_id", workflow.ID, "name", workflow.Name, "error", err) continue } if cleared { - workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID}) + workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, systemDB: dbosCtx.systemDB}) } continue } - registeredWorkflow, exists := dbos.workflowRegistry[workflow.Name] + registeredWorkflow, exists := dbosCtx.workflowRegistry[workflow.Name] if !exists { getLogger().Error("Workflow function not found in registry", "workflow_id", workflow.ID, "name", workflow.Name) continue @@ -56,7 +55,9 @@ func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]Workf opts = append(opts, WithDeadline(workflow.Deadline)) } - handle, err := registeredWorkflow.wrappedFunction(ctx, workflow.Input, opts...) + // Create a workflow context from the executor context + workflowCtx := dbosCtx.WithValue(dbosCtx.GetContext(), nil) + handle, err := registeredWorkflow.wrappedFunction(workflowCtx, workflow.Input, opts...) if err != nil { return nil, err } diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 8450b0ae..90332191 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -23,7 +23,7 @@ func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { } func encodingWorkflowBuiltinTypes(ctx context.Context, input string) (string, error) { - stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123) + stepResult, err := RunAsStep(ctx, dbos, encodingStepBuiltinTypes, 123) return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err) } @@ -49,7 +49,7 @@ type SimpleStruct struct { } func encodingWorkflowStruct(ctx context.Context, input WorkflowInputStruct) (StepOutputStruct, error) { - return RunAsStep(ctx, encodingStepStruct, StepInputStruct{ + return RunAsStep(ctx, dbos, encodingStepStruct, StepInputStruct{ A: input.A, B: fmt.Sprintf("%d", input.B), }) @@ -66,8 +66,8 @@ func TestWorkflowEncoding(t *testing.T) { executor := setupDBOS(t) // Create workflows with executor - builtinWf := WithWorkflow(executor, encodingWorkflowBuiltinTypes) - structWf := WithWorkflow(executor, encodingWorkflowStruct) + builtinWf := RegisterWorkflow(executor, encodingWorkflowBuiltinTypes) + structWf := RegisterWorkflow(executor, encodingWorkflowStruct) t.Run("BuiltinTypes", func(t *testing.T) { // Test a workflow that uses a built-in type (string) @@ -86,7 +86,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[string](directHandle.GetWorkflowID()) + retrieveHandler, err := RetrieveWorkflow[string](dbos, directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to retrieve workflow: %v", err) } @@ -194,7 +194,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](directHandle.GetWorkflowID()) + retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](dbos, directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to retrieve step workflow: %v", err) } @@ -313,7 +313,7 @@ func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, }, } - err := SetEvent(ctx, WorkflowSetEventInput[UserDefinedEventData]{Key: input, Message: eventData}) + err := SetEvent(ctx, dbos, WorkflowSetEventInput[UserDefinedEventData]{Key: input, Message: eventData}) if err != nil { return "", err } @@ -324,7 +324,7 @@ func TestSetEventSerialize(t *testing.T) { executor := setupDBOS(t) // Create workflow with executor - setEventUserDefinedTypeWf := WithWorkflow(executor, setEventUserDefinedTypeWorkflow) + setEventUserDefinedTypeWf := RegisterWorkflow(executor, setEventUserDefinedTypeWorkflow) t.Run("SetEventUserDefinedType", func(t *testing.T) { // Start a workflow that sets an event with a user-defined type @@ -343,7 +343,7 @@ func TestSetEventSerialize(t *testing.T) { } // Retrieve the event to verify it was properly serialized and can be deserialized - retrievedEvent, err := GetEvent[UserDefinedEventData](context.Background(), WorkflowGetEventInput{ + retrievedEvent, err := GetEvent[UserDefinedEventData](context.Background(), dbos, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "user-defined-key", Timeout: 3 * time.Second, @@ -390,7 +390,7 @@ func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (str // Send should automatically register this type with gob // Note the explicit type parameter since compiler cannot infer UserDefinedEventData from string input - err := Send(ctx, WorkflowSendInput[UserDefinedEventData]{ + err := Send(ctx, dbos, WorkflowSendInput[UserDefinedEventData]{ DestinationID: destinationID, Topic: "user-defined-topic", Message: sendData, @@ -403,7 +403,7 @@ func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (str func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefinedEventData, error) { // Receive the user-defined type message - result, err := Recv[UserDefinedEventData](ctx, WorkflowRecvInput{ + result, err := Recv[UserDefinedEventData](ctx, dbos, WorkflowRecvInput{ Topic: "user-defined-topic", Timeout: 3 * time.Second, }) @@ -414,8 +414,8 @@ func TestSendSerialize(t *testing.T) { executor := setupDBOS(t) // Create workflows with executor - sendUserDefinedTypeWf := WithWorkflow(executor, sendUserDefinedTypeWorkflow) - recvUserDefinedTypeWf := WithWorkflow(executor, recvUserDefinedTypeWorkflow) + sendUserDefinedTypeWf := RegisterWorkflow(executor, sendUserDefinedTypeWorkflow) + recvUserDefinedTypeWf := RegisterWorkflow(executor, recvUserDefinedTypeWorkflow) t.Run("SendUserDefinedType", func(t *testing.T) { // Start a receiver workflow first diff --git a/dbos/system_database.go b/dbos/system_database.go index 1ccba59f..46f964d9 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -45,16 +45,16 @@ type SystemDatabase interface { GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) // Communication (special steps) - Send(ctx context.Context, input workflowSendInputInternal) error + Send(ctx context.Context, input WorkflowSendInputInternal) error Recv(ctx context.Context, input WorkflowRecvInput) (any, error) - SetEvent(ctx context.Context, input workflowSetEventInputInternal) error + SetEvent(ctx context.Context, input WorkflowSetEventInputInternal) error GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) // Timers (special steps) Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) // Queues - DequeueWorkflows(ctx context.Context, queue WorkflowQueue) ([]dequeuedWorkflow, error) + DequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) ClearQueueAssignment(ctx context.Context, workflowID string) (bool, error) } @@ -239,17 +239,18 @@ type insertWorkflowResult struct { name string queueName *string workflowDeadlineEpochMs *int64 + tx pgx.Tx } type insertWorkflowStatusDBInput struct { status WorkflowStatus maxRetries int - tx pgx.Tx } func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) { - if input.tx == nil { - return nil, errors.New("transaction is required for InsertWorkflowStatus") + tx, err := s.pool.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("failed to begin transaction: %w", err) } // Set default values @@ -312,8 +313,10 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW END RETURNING recovery_attempts, status, name, queue_name, workflow_deadline_epoch_ms` - var result insertWorkflowResult - err = input.tx.QueryRow(ctx, query, + result := insertWorkflowResult{ + tx: tx, + } + err = tx.QueryRow(ctx, query, input.status.ID, input.status.Status, input.status.Name, @@ -362,7 +365,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW SET status = $1, deduplication_id = NULL, started_at_epoch_ms = NULL, queue_name = NULL WHERE workflow_uuid = $2 AND status = $3` - _, err = input.tx.Exec(ctx, dlqQuery, + _, err = tx.Exec(ctx, dlqQuery, WorkflowStatusRetriesExceeded, input.status.ID, WorkflowStatusPending) @@ -372,7 +375,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW } // Commit the transaction before throwing the error - if err := input.tx.Commit(ctx); err != nil { + if err := tx.Commit(ctx); err != nil { return nil, fmt.Errorf("failed to commit transaction after marking workflow as RETRIES_EXCEEDED: %w", err) } @@ -1115,16 +1118,16 @@ func (s *systemDatabase) notificationListenerLoop(ctx context.Context) { const _DBOS_NULL_TOPIC = "__null__topic__" -type workflowSendInputInternal struct { - destinationID string - message any - topic string +type WorkflowSendInputInternal struct { + DestinationID string + Message any + Topic string } // Send is a special type of step that sends a message to another workflow. // Can be called both within a workflow (as a step) or outside a workflow (directly). // When called within a workflow: durability and the function run in the same transaction, and we forbid nested step execution -func (s *systemDatabase) Send(ctx context.Context, input workflowSendInputInternal) error { +func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInputInternal) error { functionName := "DBOS.send" // Get workflow state from context (optional for Send as we can send from outside a workflow) @@ -1166,22 +1169,22 @@ func (s *systemDatabase) Send(ctx context.Context, input workflowSendInputIntern // Set default topic if not provided topic := _DBOS_NULL_TOPIC - if len(input.topic) > 0 { - topic = input.topic + if len(input.Topic) > 0 { + topic = input.Topic } // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.message) + messageString, err := serialize(input.Message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) } insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)` - _, err = tx.Exec(ctx, insertQuery, input.destinationID, topic, messageString) + _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString) if err != nil { // Check for foreign key violation (destination workflow doesn't exist) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23503" { - return newNonExistentWorkflowError(input.destinationID) + return newNonExistentWorkflowError(input.DestinationID) } return fmt.Errorf("failed to insert notification: %w", err) } @@ -1357,12 +1360,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any return message, nil } -type workflowSetEventInputInternal struct { - key string - message any +type WorkflowSetEventInputInternal struct { + Key string + Message any } -func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInputInternal) error { +func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInputInternal) error { functionName := "DBOS.setEvent" // Get workflow state from context @@ -1400,7 +1403,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInp } // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.message) + messageString, err := serialize(input.Message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) } @@ -1411,7 +1414,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInp ON CONFLICT (workflow_uuid, key) DO UPDATE SET value = EXCLUDED.value` - _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.key, messageString) + _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, messageString) if err != nil { return fmt.Errorf("failed to insert/update workflow event: %w", err) } @@ -1567,7 +1570,8 @@ type dequeuedWorkflow struct { input string } -func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQueue) ([]dequeuedWorkflow, error) { +// TODO input struct +func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) { // Begin transaction with snapshot isolation tx, err := s.pool.Begin(ctx) if err != nil { @@ -1638,7 +1642,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue pendingWorkflowsDict[executorIDRow] = taskCount } - localPendingWorkflows := pendingWorkflowsDict[dbos.executorID] + localPendingWorkflows := pendingWorkflowsDict[executorID] // Check worker concurrency limit if queue.workerConcurrency != nil { @@ -1705,7 +1709,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue } // Execute the query to get workflow IDs - rows, err := tx.Query(ctx, query, queue.name, WorkflowStatusEnqueued, dbos.applicationVersion) + rows, err := tx.Query(ctx, query, queue.name, WorkflowStatusEnqueued, applicationVersion) if err != nil { return nil, fmt.Errorf("failed to query enqueued workflows: %w", err) } @@ -1755,8 +1759,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue var inputString *string err := tx.QueryRow(ctx, updateQuery, WorkflowStatusPending, - dbos.applicationVersion, - dbos.executorID, + applicationVersion, + executorID, startTimeMs, id).Scan(&retWorkflow.name, &inputString) diff --git a/dbos/utils_test.go b/dbos/utils_test.go index 609cf9ba..610e45e1 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -25,7 +25,7 @@ func getDatabaseURL(t *testing.T) string { } /* Test database setup */ -func setupDBOS(t *testing.T) DBOSExecutor { +func setupDBOS(t *testing.T) DBOSContext { t.Helper() databaseURL := getDatabaseURL(t) @@ -54,7 +54,7 @@ func setupDBOS(t *testing.T) DBOSExecutor { t.Fatalf("failed to drop test database: %v", err) } - executor, err := Initialize(Config{ + executor, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) @@ -74,9 +74,8 @@ func setupDBOS(t *testing.T) DBOSExecutor { // Register cleanup to run after test completes t.Cleanup(func() { fmt.Println("Cleaning up DBOS instance...") - if dbos != nil { - dbos.Shutdown() - dbos = nil + if executor != nil { + executor.Shutdown() } }) @@ -120,27 +119,31 @@ func (e *Event) Clear() { /* Helpers */ // stopQueueRunner stops the queue runner for testing purposes -func stopQueueRunner() { - if dbos != nil && dbos.queueRunnerCancelFunc != nil { - dbos.queueRunnerCancelFunc() - // Wait for queue runner to finish - <-dbos.queueRunnerDone +func stopQueueRunner(executor DBOSContext) { + if executor != nil { + exec := executor.(*dbosContext) + if exec.queueRunnerCancelFunc != nil { + exec.queueRunnerCancelFunc() + // Wait for queue runner to finish + <-exec.queueRunnerDone + } } } // restartQueueRunner restarts the queue runner for testing purposes -func restartQueueRunner() { - if dbos != nil { +func restartQueueRunner(executor DBOSContext) { + if executor != nil { + exec := executor.(*dbosContext) // Create new context and cancel function ctx, cancel := context.WithCancel(context.Background()) - dbos.queueRunnerCtx = ctx - dbos.queueRunnerCancelFunc = cancel - dbos.queueRunnerDone = make(chan struct{}) + exec.queueRunnerCtx = ctx + exec.queueRunnerCancelFunc = cancel + exec.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(dbos.queueRunnerDone) - queueRunner(ctx) + defer close(exec.queueRunnerDone) + queueRunner(ctx, exec) }() } } @@ -157,12 +160,13 @@ func equal(a, b []int) bool { return true } -func queueEntriesAreCleanedUp() bool { +func queueEntriesAreCleanedUp(executor DBOSContext) bool { maxTries := 10 success := false for range maxTries { // Begin transaction - tx, err := dbos.systemDB.(*systemDatabase).pool.Begin(context.Background()) + exec := executor.(*dbosContext) + tx, err := exec.systemDB.(*systemDatabase).pool.Begin(context.Background()) if err != nil { return false } diff --git a/dbos/workflow.go b/dbos/workflow.go index ec1424aa..b4501021 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -54,7 +54,6 @@ type WorkflowStatus struct { } // workflowState holds the runtime state for a workflow execution -// TODO: this should be an internal type. Workflows should have aptly named getters to access the state type workflowState struct { workflowID string stepCounter int @@ -80,13 +79,14 @@ type workflowOutcome[R any] struct { type WorkflowHandle[R any] interface { GetResult(ctx context.Context) (R, error) GetStatus() (WorkflowStatus, error) - GetWorkflowID() string // XXX we could have a base struct with GetWorkflowID and then embed it in the implementations + GetWorkflowID() string } // workflowHandle is a concrete implementation of WorkflowHandle type workflowHandle[R any] struct { workflowID string outcomeChan chan workflowOutcome[R] + systemDB SystemDatabase } // GetResult waits for the workflow to complete and returns the result @@ -111,7 +111,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { output: encodedOutput, err: outcome.err, } - recordResultErr := dbos.systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := h.systemDB.RecordChildGetResult(ctx, recordGetResultInput) if recordResultErr != nil { getLogger().Error("failed to record get result", "error", recordResultErr) return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) @@ -123,7 +123,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowHandle[R]) GetStatus() (WorkflowStatus, error) { ctx := context.Background() - workflowStatuses, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowStatuses, err := h.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -141,10 +141,11 @@ func (h *workflowHandle[R]) GetWorkflowID() string { type workflowPollingHandle[R any] struct { workflowID string + systemDB SystemDatabase } func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { - result, err := dbos.systemDB.AwaitWorkflowResult(ctx, h.workflowID) + result, err := h.systemDB.AwaitWorkflowResult(ctx, h.workflowID) if result != nil { typedResult, ok := result.(R) if !ok { @@ -166,7 +167,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { output: encodedOutput, err: err, } - recordResultErr := dbos.systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := h.systemDB.RecordChildGetResult(ctx, recordGetResultInput) if recordResultErr != nil { // XXX do we want to fail this? getLogger().Error("failed to record get result", "error", recordResultErr) @@ -180,7 +181,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowPollingHandle[R]) GetStatus() (WorkflowStatus, error) { ctx := context.Background() - workflowStatuses, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowStatuses, err := h.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -199,7 +200,7 @@ func (h *workflowPollingHandle[R]) GetWorkflowID() string { /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ -type typedErasedWorkflowWrapperFunc func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) +type typedErasedWorkflowWrapperFunc func(ctx DBOSContext, input any, opts ...workflowOption) (WorkflowHandle[any], error) type workflowRegistryEntry struct { wrappedFunction typedErasedWorkflowWrapperFunc @@ -207,7 +208,7 @@ type workflowRegistryEntry struct { } // Register adds a workflow function to the registry (thread-safe, only once per name) -func (e *executor) RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { +func (e *dbosContext) RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { e.workflowRegMutex.Lock() defer e.workflowRegMutex.Unlock() @@ -222,7 +223,7 @@ func (e *executor) RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFun } } -func (e *executor) RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) { +func (e *dbosContext) RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) { e.GetWorkflowScheduler().Start() var entryID cron.EntryID entryID, err := e.GetWorkflowScheduler().AddFunc(cronSchedule, func() { @@ -238,7 +239,7 @@ func (e *executor) RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowW scheduledTime = entry.Next } wfID := fmt.Sprintf("sched-%s-%s", fqn, scheduledTime) // XXX we can rethink the format - fn(context.Background(), scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) + fn(e, scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) }) if err != nil { panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) @@ -270,10 +271,11 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } -func WithWorkflow[P any, R any](dbosExecutor DBOSExecutor, fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { - if dbosExecutor == nil { +// TODO split RegisterWorkflow and RegisterScheduledWorkflow +func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { + if dbosCtx == nil { // TODO: consider panic here - getLogger().Error("Provide a valid DBOS executor instance") + getLogger().Error("Provide a valid DBOS context") return nil } @@ -296,13 +298,13 @@ func WithWorkflow[P any, R any](dbosExecutor DBOSExecutor, fn WorkflowFunc[P, R] gob.Register(r) // Wrap the function in a durable workflow - wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx context.Context, workflowInput P, opts ...workflowOption) (WorkflowHandle[R], error) { + wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx DBOSContext, workflowInput P, opts ...workflowOption) (WorkflowHandle[R], error) { opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) - return runAsWorkflow(ctx, dbosExecutor, fn, workflowInput, opts...) + return runAsWorkflow(ctx, fn, workflowInput, opts...) }) // Register a type-erased version of the durable workflow for recovery - typeErasedWrapper := func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) { + typeErasedWrapper := func(ctx DBOSContext, input any, opts ...workflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) @@ -312,16 +314,16 @@ func WithWorkflow[P any, R any](dbosExecutor DBOSExecutor, fn WorkflowFunc[P, R] if err != nil { return nil, err } - return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID()}, nil + return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), systemDB: ctx.(*dbosContext).systemDB}, nil } - dbosExecutor.RegisterWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) + dbosCtx.RegisterWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) // If this is a scheduled workflow, register a cron job if registrationParams.cronSchedule != "" { if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { panic(fmt.Sprintf("scheduled workflow function must accept a time.Time as input, got %T", p)) } - dbosExecutor.RegisterScheduledWorkflow(fqn, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) + dbosCtx.RegisterScheduledWorkflow(fqn, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) } return wrappedFunction @@ -335,8 +337,8 @@ type contextKey string const workflowStateKey contextKey = "workflowState" -type WorkflowFunc[P any, R any] func(ctx context.Context, input P) (R, error) -type WorkflowWrapperFunc[P any, R any] func(ctx context.Context, input P, opts ...workflowOption) (WorkflowHandle[R], error) +type WorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P) (R, error) +type WorkflowWrapperFunc[P any, R any] func(dbosCtx DBOSContext, input P, opts ...workflowOption) (WorkflowHandle[R], error) type workflowParams struct { workflowID string @@ -385,20 +387,17 @@ func WithWorkflowMaxRetries(maxRetries int) workflowOption { } } -func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { +func runAsWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { // Apply options to build params params := workflowParams{ - applicationVersion: dbosExecutor.GetApplicationVersion(), + applicationVersion: dbosCtx.GetApplicationVersion(), } for _, opt := range opts { opt(¶ms) } - // First, create a context for the workflow - dbosWorkflowContext := context.Background() - // Check if we are within a workflow (and thus a child workflow) - parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) + parentWorkflowState, ok := dbosCtx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil // TODO Check if cancelled @@ -418,12 +417,12 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := dbos.systemDB.CheckChildWorkflow(dbosWorkflowContext, parentWorkflowState.workflowID, parentWorkflowState.stepCounter) + childWorkflowID, err := dbosCtx.CheckChildWorkflow(parentWorkflowState.workflowID, parentWorkflowState.stepCounter) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { - return &workflowPollingHandle[R]{workflowID: *childWorkflowID}, nil + return &workflowPollingHandle[R]{workflowID: *childWorkflowID, systemDB: dbosCtx.(*dbosContext).systemDB}, nil } } @@ -437,42 +436,31 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, workflowStatus := WorkflowStatus{ Name: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // TODO factor out somewhere else so we dont' have to reflect here ApplicationVersion: params.applicationVersion, - ExecutorID: dbos.executorID, + ExecutorID: dbosCtx.GetExecutorID(), Status: status, ID: workflowID, CreatedAt: time.Now(), Deadline: params.deadline, // TODO compute the deadline based on the timeout Timeout: params.timeout, Input: input, - ApplicationID: dbos.applicationID, + ApplicationID: dbosCtx.GetApplicationID(), QueueName: params.queueName, } // Init status and record child workflow relationship in a single transaction - tx, err := dbos.systemDB.(*systemDatabase).pool.Begin(dbosWorkflowContext) - if err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) - } - defer tx.Rollback(dbosWorkflowContext) // Rollback if not committed - - // Insert workflow status with transaction - insertInput := insertWorkflowStatusDBInput{ - status: workflowStatus, - maxRetries: params.maxRetries, - tx: tx, - } - insertStatusResult, err := dbos.systemDB.InsertWorkflowStatus(dbosWorkflowContext, insertInput) + insertStatusResult, err := dbosCtx.InsertWorkflowStatus(workflowStatus, params.maxRetries) if err != nil { return nil, err } + defer insertStatusResult.tx.Rollback(dbosCtx.GetContext()) // Rollback if not committed // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { // Commit the transaction to update the number of attempts and/or enact the enqueue - if err := tx.Commit(dbosWorkflowContext); err != nil { + if err := insertStatusResult.tx.Commit(dbosCtx.GetContext()); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowPollingHandle[R]{workflowID: workflowStatus.ID}, nil + return &workflowPollingHandle[R]{workflowID: workflowStatus.ID, systemDB: dbosCtx.(*dbosContext).systemDB}, nil } // Record child workflow relationship if this is a child workflow @@ -484,19 +472,14 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, childWorkflowID: workflowStatus.ID, stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this stepID: stepID, - tx: tx, + tx: insertStatusResult.tx, } - err = dbos.systemDB.RecordChildWorkflow(dbosWorkflowContext, childInput) + err = dbosCtx.RecordChildWorkflow(childInput) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow: %v", err)) } } - // Commit the transaction - if err := tx.Commit(dbosWorkflowContext); err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) - } - // Channel to receive the outcome from the goroutine // The buffer size of 1 allows the goroutine to send the outcome without blocking // In addition it allows the channel to be garbage collected @@ -506,6 +489,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, handle := &workflowHandle[R]{ workflowID: workflowStatus.ID, outcomeChan: outcomeChan, + systemDB: dbosCtx.(*dbosContext).systemDB, } // Create workflow state to track step execution @@ -515,16 +499,16 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, } // Run the function in a goroutine - augmentUserContext := context.WithValue(ctx, workflowStateKey, wfState) - dbos.workflowsWg.Add(1) + augmentUserContext := dbosCtx.WithValue(workflowStateKey, wfState) + dbosCtx.GetWorkflowWg().Add(1) go func() { - defer dbos.workflowsWg.Done() + defer dbosCtx.GetWorkflowWg().Done() result, err := fn(augmentUserContext, input) status := WorkflowStatusSuccess if err != nil { status = WorkflowStatusError } - recordErr := dbos.systemDB.UpdateWorkflowOutcome(dbosWorkflowContext, updateWorkflowOutcomeDBInput{workflowID: workflowStatus.ID, status: status, err: err, output: result}) + recordErr := dbosCtx.UpdateWorkflowOutcome(workflowStatus.ID, status, err, result) if recordErr != nil { outcomeChan <- workflowOutcome[R]{result: *new(R), err: recordErr} close(outcomeChan) // Close the channel to signal completion @@ -534,20 +518,10 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, close(outcomeChan) // Close the channel to signal completion }() - // Run the peer goroutine to handle cancellation and timeout - /* - if dbosWorkflowContext.Done() != nil { - getLogger().Debug("starting goroutine to handle workflow context cancellation or timeout") - go func() { - select { - case <-dbosWorkflowContext.Done(): - // The context was cancelled or timed out: record timeout or cancellation - // CANCEL WORKFLOW HERE - return - } - }() - } - */ + // Commit the transaction + if err := insertStatusResult.tx.Commit(dbosCtx.GetContext()); err != nil { + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) + } return handle, nil } @@ -558,6 +532,8 @@ func runAsWorkflow[P any, R any](ctx context.Context, dbosExecutor DBOSExecutor, type StepFunc[P any, R any] func(ctx context.Context, input P) (R, error) +type TypeErasedStepFunc func(ctx context.Context, input any) (any, error) + type StepParams struct { MaxRetries int BackoffFactor float64 @@ -596,13 +572,11 @@ func WithMaxInterval(maxInterval time.Duration) stepOption { } } -func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, opts ...stepOption) (R, error) { +func (e *dbosContext) RunAsStep(fn TypeErasedStepFunc, input any, stepName string, opts ...stepOption) (any, error) { if fn == nil { - return *new(R), newStepExecutionError("", "", "step function cannot be nil") + return nil, newStepExecutionError("", "", "step function cannot be nil") } - stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() - // Apply options to build params with defaults params := StepParams{ MaxRetries: 0, @@ -615,30 +589,30 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op } // Get workflow state from context - wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + wfState, ok := e.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { - return *new(R), newStepExecutionError("", stepName, "workflow state not found in context: are you running this step within a workflow?") + return nil, newStepExecutionError("", stepName, "workflow state not found in context: are you running this step within a workflow?") } // If within a step, just run the function directly if wfState.isWithinStep { - return fn(ctx, input) + return fn(e.ctx, input) } // Get next step ID stepID := wfState.NextStepID() // Check the step is cancelled, has already completed, or is called with a different name - recordedOutput, err := dbos.systemDB.CheckOperationExecution(ctx, checkOperationExecutionDBInput{ + recordedOutput, err := e.systemDB.CheckOperationExecution(e.ctx, checkOperationExecutionDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: stepName, }) if err != nil { - return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) + return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) } if recordedOutput != nil { - return recordedOutput.output.(R), recordedOutput.err + return recordedOutput.output, recordedOutput.err } // Execute step with retry logic if MaxRetries > 0 @@ -647,7 +621,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op stepCounter: wfState.stepCounter, isWithinStep: true, } - stepCtx := context.WithValue(ctx, workflowStateKey, &stepState) + stepCtx := e.WithValue(workflowStateKey, &stepState) stepOutput, stepError := fn(stepCtx, input) @@ -668,8 +642,8 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // Wait before retry select { - case <-ctx.Done(): - return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", ctx.Err())) + case <-e.ctx.Done(): + return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", e.ctx.Err())) case <-time.After(delay): // Continue to retry } @@ -701,14 +675,45 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op err: stepError, output: stepOutput, } - recErr := dbos.systemDB.RecordOperationResult(ctx, dbInput) + recErr := e.systemDB.RecordOperationResult(e.ctx, dbInput) if recErr != nil { - return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) + return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) } return stepOutput, stepError } +func RunAsStep[P any, R any](dbosCtx DBOSContext, fn StepFunc[P, R], input P, opts ...stepOption) (R, error) { + if fn == nil { + return *new(R), newStepExecutionError("", "", "step function cannot be nil") + } + + stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + + // Type-erase the function + typeErasedFn := func(ctx context.Context, input any) (any, error) { + typedInput, ok := input.(P) + if !ok { + return nil, fmt.Errorf("unexpected input type: expected %T, got %T", *new(P), input) + } + return fn(ctx, typedInput) + } + + // Call the executor method + result, err := dbosCtx.RunAsStep(typeErasedFn, input, stepName, opts...) + if err != nil { + return *new(R), err + } + + // Type-check and cast the result + typedResult, ok := result.(R) + if !ok { + return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result) + } + + return typedResult, nil +} + /****************************************/ /******* WORKFLOW COMMUNICATIONS ********/ /****************************************/ @@ -719,15 +724,19 @@ type WorkflowSendInput[R any] struct { Topic string } +func (e *dbosContext) Send(input WorkflowSendInputInternal) error { + return e.systemDB.Send(e.ctx, input) +} + // Send sends a message to another workflow. // Send automatically registers the type of R for gob encoding -func Send[R any](ctx context.Context, input WorkflowSendInput[R]) error { +func Send[R any](dbosCtx DBOSContext, input WorkflowSendInput[R]) error { var typedMessage R gob.Register(typedMessage) - return dbos.systemDB.Send(ctx, workflowSendInputInternal{ - destinationID: input.DestinationID, - message: input.Message, - topic: input.Topic, + return dbosCtx.Send(WorkflowSendInputInternal{ + DestinationID: input.DestinationID, + Message: input.Message, + Topic: input.Topic, }) } @@ -736,8 +745,12 @@ type WorkflowRecvInput struct { Timeout time.Duration } -func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { - msg, err := dbos.systemDB.Recv(ctx, input) +func (e *dbosContext) Recv(input WorkflowRecvInput) (any, error) { + return e.systemDB.Recv(e.ctx, input) +} + +func Recv[R any](dbosCtx DBOSContext, input WorkflowRecvInput) (R, error) { + msg, err := dbosCtx.Recv(input) if err != nil { return *new(R), err } @@ -758,15 +771,19 @@ type WorkflowSetEventInput[R any] struct { Message R } +func (e *dbosContext) SetEvent(input WorkflowSetEventInputInternal) error { + return e.systemDB.SetEvent(e.ctx, input) +} + // Sets an event from a workflow. // The event is a key value pair // SetEvent automatically registers the type of R for gob encoding -func SetEvent[R any](ctx context.Context, input WorkflowSetEventInput[R]) error { +func SetEvent[R any](dbosCtx DBOSContext, input WorkflowSetEventInput[R]) error { var typedMessage R gob.Register(typedMessage) - return dbos.systemDB.SetEvent(ctx, workflowSetEventInputInternal{ - key: input.Key, - message: input.Message, + return dbosCtx.SetEvent(WorkflowSetEventInputInternal{ + Key: input.Key, + Message: input.Message, }) } @@ -776,8 +793,12 @@ type WorkflowGetEventInput struct { Timeout time.Duration } -func GetEvent[R any](ctx context.Context, input WorkflowGetEventInput) (R, error) { - value, err := dbos.systemDB.GetEvent(ctx, input) +func (e *dbosContext) GetEvent(input WorkflowGetEventInput) (any, error) { + return e.systemDB.GetEvent(e.ctx, input) +} + +func GetEvent[R any](dbosCtx DBOSContext, input WorkflowGetEventInput) (R, error) { + value, err := dbosCtx.GetEvent(input) if err != nil { return *new(R), err } @@ -792,8 +813,12 @@ func GetEvent[R any](ctx context.Context, input WorkflowGetEventInput) (R, error return typedValue, nil } -func Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) { - return dbos.systemDB.Sleep(ctx, duration) +func (e *dbosContext) Sleep(duration time.Duration) (time.Duration, error) { + return e.systemDB.Sleep(e.ctx, duration) +} + +func Sleep(dbosCtx DBOSContext, duration time.Duration) (time.Duration, error) { + return dbosCtx.Sleep(duration) } /***********************************/ @@ -801,24 +826,52 @@ func Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) { /***********************************/ // GetWorkflowID retrieves the workflow ID from the context if called within a DBOS workflow -func GetWorkflowID(ctx context.Context) (string, error) { - wfState, ok := ctx.Value(workflowStateKey).(*workflowState) +func (e *dbosContext) GetWorkflowID() (string, error) { + wfState, ok := e.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { return "", errors.New("not within a DBOS workflow context") } return wfState.workflowID, nil } -func RetrieveWorkflow[R any](workflowID string) (workflowPollingHandle[R], error) { - ctx := context.Background() - workflowStatus, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ - workflowIDs: []string{workflowID}, +func (e *dbosContext) RetrieveWorkflow(workflowIDs []string) ([]WorkflowStatus, error) { + return e.systemDB.ListWorkflows(e.ctx, listWorkflowsDBInput{ + workflowIDs: workflowIDs, }) +} + +func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { + workflowStatus, err := dbosCtx.RetrieveWorkflow([]string{workflowID}) if err != nil { return workflowPollingHandle[R]{}, fmt.Errorf("failed to retrieve workflow status: %w", err) } if len(workflowStatus) == 0 { return workflowPollingHandle[R]{}, newNonExistentWorkflowError(workflowID) } - return workflowPollingHandle[R]{workflowID: workflowID}, nil + return workflowPollingHandle[R]{workflowID: workflowID, systemDB: dbosCtx.(*dbosContext).systemDB}, nil +} + +func (e *dbosContext) CheckChildWorkflow(parentWorkflowID string, stepCounter int) (*string, error) { + return e.systemDB.CheckChildWorkflow(e.ctx, parentWorkflowID, stepCounter) +} + +func (e *dbosContext) InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*insertWorkflowResult, error) { + return e.systemDB.InsertWorkflowStatus(e.ctx, insertWorkflowStatusDBInput{ + status: status, + maxRetries: maxRetries, + }) +} + +func (e *dbosContext) RecordChildWorkflow(input recordChildWorkflowDBInput) error { + return e.systemDB.RecordChildWorkflow(e.ctx, input) +} + +func (e *dbosContext) UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error { + return e.systemDB.UpdateWorkflowOutcome(e.ctx, updateWorkflowOutcomeDBInput{ + workflowID: workflowID, + status: status, + err: err, + output: output, + tx: nil, // No explicit transaction for this interface method + }) } diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 0789df9f..1fd2dcdc 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -33,7 +33,7 @@ func simpleWorkflowError(ctx context.Context, input string) (int, error) { } func simpleWorkflowWithStep(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, simpleStep, input) + return RunAsStep(ctx, dbos, simpleStep, input) } func simpleStep(ctx context.Context, input string) (string, error) { @@ -45,7 +45,7 @@ func simpleStepError(ctx context.Context, input string) (string, error) { } func simpleWorkflowWithStepError(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, simpleStepError, input) + return RunAsStep(ctx, dbos, simpleStepError, input) } // idempotencyWorkflow increments a global counter and returns the input @@ -89,28 +89,28 @@ func TestWorkflowsWrapping(t *testing.T) { executor := setupDBOS(t) // Setup workflows with executor - simpleWf := WithWorkflow(executor, simpleWorkflow) - simpleWfError := WithWorkflow(executor, simpleWorkflowError) - simpleWfWithStep := WithWorkflow(executor, simpleWorkflowWithStep) - simpleWfWithStepError := WithWorkflow(executor, simpleWorkflowWithStepError) + simpleWf := RegisterWorkflow(executor, simpleWorkflow) + simpleWfError := RegisterWorkflow(executor, simpleWorkflowError) + simpleWfWithStep := RegisterWorkflow(executor, simpleWorkflowWithStep) + simpleWfWithStepError := RegisterWorkflow(executor, simpleWorkflowWithStepError) // struct methods s := workflowStruct{} - simpleWfStruct := WithWorkflow(executor, s.simpleWorkflow) - simpleWfValue := WithWorkflow(executor, s.simpleWorkflowValue) + simpleWfStruct := RegisterWorkflow(executor, s.simpleWorkflow) + simpleWfValue := RegisterWorkflow(executor, s.simpleWorkflowValue) // interface method workflow workflowIface := TestWorkflowInterface(&workflowImplementation{ field: "example", }) - simpleWfIface := WithWorkflow(executor, workflowIface.Execute) + simpleWfIface := RegisterWorkflow(executor, workflowIface.Execute) // Generic workflow - wfInt := WithWorkflow(executor, Identity[string]) // FIXME make this an int eventually + wfInt := RegisterWorkflow(executor, Identity[string]) // FIXME make this an int eventually // Closure with captured state prefix := "hello-" - wfClose := WithWorkflow(executor, func(ctx context.Context, in string) (string, error) { + wfClose := RegisterWorkflow(executor, func(ctx context.Context, in string) (string, error) { return prefix + in, nil }) // Anonymous workflow - anonymousWf := WithWorkflow(executor, func(ctx context.Context, in string) (string, error) { + anonymousWf := RegisterWorkflow(executor, func(ctx context.Context, in string) (string, error) { return "anonymous-" + in, nil }) @@ -290,11 +290,11 @@ func TestWorkflowsWrapping(t *testing.T) { } func stepWithinAStep(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, simpleStep, input) + return RunAsStep(ctx, dbos, simpleStep, input) } func stepWithinAStepWorkflow(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, stepWithinAStep, input) + return RunAsStep(ctx, dbos, stepWithinAStep, input) } // Global counter for retry testing @@ -313,8 +313,8 @@ func stepIdempotencyTest(ctx context.Context, input string) (string, error) { } func stepRetryWorkflow(ctx context.Context, input string) (string, error) { - RunAsStep(ctx, stepIdempotencyTest, input) - return RunAsStep(ctx, stepRetryAlwaysFailsStep, input, + RunAsStep(ctx, dbos, stepIdempotencyTest, input) + return RunAsStep(ctx, dbos, stepRetryAlwaysFailsStep, input, WithStepMaxRetries(5), WithBackoffFactor(2.0), WithBaseInterval(1*time.Millisecond), @@ -325,14 +325,14 @@ func TestSteps(t *testing.T) { executor := setupDBOS(t) // Create workflows with executor - stepWithinAStepWf := WithWorkflow(executor, stepWithinAStepWorkflow) - stepRetryWf := WithWorkflow(executor, stepRetryWorkflow) + stepWithinAStepWf := RegisterWorkflow(executor, stepWithinAStepWorkflow) + stepRetryWf := RegisterWorkflow(executor, stepRetryWorkflow) t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { ctx := context.Background() // Attempt to run a step outside of a workflow context - _, err := RunAsStep(ctx, simpleStep, "test") + _, err := RunAsStep(ctx, dbos, simpleStep, "test") if err == nil { t.Fatal("expected error when running step outside of workflow context, but got none") } @@ -455,7 +455,7 @@ func TestChildWorkflow(t *testing.T) { executor := setupDBOS(t) // Create child workflows with executor - childWf := WithWorkflow(executor, func(ctx context.Context, i int) (string, error) { + childWf := RegisterWorkflow(executor, func(ctx context.Context, i int) (string, error) { workflowID, err := GetWorkflowID(ctx) if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) @@ -465,10 +465,10 @@ func TestChildWorkflow(t *testing.T) { return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedCurrentID, workflowID) } // XXX right now the steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow - return RunAsStep(ctx, simpleStep, "") + return RunAsStep(ctx, dbos, simpleStep, "") }) - parentWf := WithWorkflow(executor, func(ctx context.Context, i int) (string, error) { + parentWf := RegisterWorkflow(executor, func(ctx context.Context, i int) (string, error) { workflowID, err := GetWorkflowID(ctx) if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) @@ -494,7 +494,7 @@ func TestChildWorkflow(t *testing.T) { return childHandle.GetResult(ctx) }) - grandParentWf := WithWorkflow(executor, func(ctx context.Context, _ string) (string, error) { + grandParentWf := RegisterWorkflow(executor, func(ctx context.Context, _ string) (string, error) { for i := range 3 { workflowID, err := GetWorkflowID(ctx) if err != nil { @@ -563,15 +563,15 @@ func blockingStep(ctx context.Context, input string) (string, error) { var idempotencyWorkflowWithStepEvent *Event func idempotencyWorkflowWithStep(ctx context.Context, input string) (int64, error) { - RunAsStep(ctx, incrementCounter, 1) + RunAsStep(ctx, dbos, incrementCounter, 1) idempotencyWorkflowWithStepEvent.Set() - RunAsStep(ctx, blockingStep, input) + RunAsStep(ctx, dbos, blockingStep, input) return idempotencyCounter, nil } func TestWorkflowIdempotency(t *testing.T) { executor := setupDBOS(t) - idempotencyWf := WithWorkflow(executor, idempotencyWorkflow) + idempotencyWf := RegisterWorkflow(executor, idempotencyWorkflow) t.Run("WorkflowExecutedOnlyOnce", func(t *testing.T) { idempotencyCounter = 0 @@ -620,7 +620,7 @@ func TestWorkflowIdempotency(t *testing.T) { func TestWorkflowRecovery(t *testing.T) { executor := setupDBOS(t) - idempotencyWfWithStep := WithWorkflow(executor, idempotencyWorkflowWithStep) + idempotencyWfWithStep := RegisterWorkflow(executor, idempotencyWorkflowWithStep) t.Run("RecoveryResumeWhereItLeftOff", func(t *testing.T) { // Reset the global counter idempotencyCounter = 0 @@ -718,8 +718,8 @@ func infiniteDeadLetterQueueWorkflow(ctx context.Context, input string) (int, er } func TestWorkflowDeadLetterQueue(t *testing.T) { executor := setupDBOS(t) - deadLetterQueueWf := WithWorkflow(executor, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) - infiniteDeadLetterQueueWf := WithWorkflow(executor, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries + deadLetterQueueWf := RegisterWorkflow(executor, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) + infiniteDeadLetterQueueWf := RegisterWorkflow(executor, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries t.Run("DeadLetterQueueBehavior", func(t *testing.T) { deadLetterQueueEvent = NewEvent() @@ -892,7 +892,7 @@ var ( func TestScheduledWorkflows(t *testing.T) { executor := setupDBOS(t) - _ = WithWorkflow(executor, func(ctx context.Context, scheduledTime time.Time) (string, error) { + _ = RegisterWorkflow(executor, func(ctx context.Context, scheduledTime time.Time) (string, error) { startTime := time.Now() counter++ if counter == 10 { @@ -977,15 +977,15 @@ type sendWorkflowInput struct { } func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) + err := Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) if err != nil { return "", err } - err = Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message2"}) + err = Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message2"}) if err != nil { return "", err } - err = Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message3"}) + err = Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message3"}) if err != nil { return "", err } @@ -993,15 +993,15 @@ func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) } func receiveWorkflow(ctx context.Context, topic string) (string, error) { - msg1, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + msg1, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } - msg2, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + msg2, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } - msg3, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + msg3, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } @@ -1020,7 +1020,7 @@ func receiveWorkflowCoordinated(ctx context.Context, input struct { concurrentRecvStartEvent.Wait() // Do a single Recv call with timeout - msg, err := Recv[string](ctx, WorkflowRecvInput{Topic: input.Topic, Timeout: 3 * time.Second}) + msg, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: input.Topic, Timeout: 3 * time.Second}) if err != nil { return "", err } @@ -1029,16 +1029,16 @@ func receiveWorkflowCoordinated(ctx context.Context, input struct { func sendStructWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { testStruct := sendRecvType{Value: "test-struct-value"} - err := Send(ctx, WorkflowSendInput[sendRecvType]{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) + err := Send(ctx, dbos, WorkflowSendInput[sendRecvType]{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) return "", err } func receiveStructWorkflow(ctx context.Context, topic string) (sendRecvType, error) { - return Recv[sendRecvType](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + return Recv[sendRecvType](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) } func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) + err := Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) if err != nil { return "", err } @@ -1047,7 +1047,7 @@ func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (stri } func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, error) { - msg, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + msg, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } @@ -1057,7 +1057,7 @@ func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, erro } func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput[string]{ + err := Send(ctx, dbos, WorkflowSendInput[string]{ DestinationID: input.DestinationID, Topic: input.Topic, Message: "message-from-step", @@ -1069,7 +1069,7 @@ func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, er } func workflowThatCallsSendInStep(ctx context.Context, input sendWorkflowInput) (string, error) { - return RunAsStep(ctx, stepThatCallsSend, input) + return RunAsStep(ctx, dbos, stepThatCallsSend, input) } type sendRecvType struct { @@ -1080,14 +1080,14 @@ func TestSendRecv(t *testing.T) { executor := setupDBOS(t) // Register all send/recv workflows with executor - sendWf := WithWorkflow(executor, sendWorkflow) - receiveWf := WithWorkflow(executor, receiveWorkflow) - receiveWfCoordinated := WithWorkflow(executor, receiveWorkflowCoordinated) - sendStructWf := WithWorkflow(executor, sendStructWorkflow) - receiveStructWf := WithWorkflow(executor, receiveStructWorkflow) - sendIdempotencyWf := WithWorkflow(executor, sendIdempotencyWorkflow) - recvIdempotencyWf := WithWorkflow(executor, receiveIdempotencyWorkflow) - sendWithinStepWf := WithWorkflow(executor, workflowThatCallsSendInStep) + sendWf := RegisterWorkflow(executor, sendWorkflow) + receiveWf := RegisterWorkflow(executor, receiveWorkflow) + receiveWfCoordinated := RegisterWorkflow(executor, receiveWorkflowCoordinated) + sendStructWf := RegisterWorkflow(executor, sendStructWorkflow) + receiveStructWf := RegisterWorkflow(executor, receiveStructWorkflow) + sendIdempotencyWf := RegisterWorkflow(executor, sendIdempotencyWorkflow) + recvIdempotencyWf := RegisterWorkflow(executor, receiveIdempotencyWorkflow) + sendWithinStepWf := RegisterWorkflow(executor, workflowThatCallsSendInStep) t.Run("SendRecvSuccess", func(t *testing.T) { // Start the receive workflow @@ -1208,7 +1208,7 @@ func TestSendRecv(t *testing.T) { ctx := context.Background() // Attempt to run Recv outside of a workflow context - _, err := Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) + _, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) if err == nil { t.Fatal("expected error when running Recv outside of workflow context, but got none") } @@ -1240,7 +1240,7 @@ func TestSendRecv(t *testing.T) { // Send messages from outside a workflow context (should work now) ctx := context.Background() for i := range 3 { - err = Send(ctx, WorkflowSendInput[string]{ + err = Send(ctx, dbos, WorkflowSendInput[string]{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "outside-workflow-topic", Message: fmt.Sprintf("message%d", i+1), @@ -1467,7 +1467,7 @@ type setEventWorkflowInput struct { } func setEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) + err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } @@ -1475,7 +1475,7 @@ func setEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, } func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - result, err := GetEvent[string](ctx, WorkflowGetEventInput{ + result, err := GetEvent[string](ctx, dbos, WorkflowGetEventInput{ TargetWorkflowID: input.Key, // Reusing Key field as target workflow ID Key: input.Message, // Reusing Message field as event key Timeout: 3 * time.Second, @@ -1488,7 +1488,7 @@ func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { // Set the first event - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: "event1", Message: "first-event-message"}) + err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: "event1", Message: "first-event-message"}) if err != nil { return "", err } @@ -1497,7 +1497,7 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str setSecondEventSignal.Wait() // Set the second event - err = SetEvent(ctx, WorkflowSetEventInput[string]{Key: "event2", Message: "second-event-message"}) + err = SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: "event2", Message: "second-event-message"}) if err != nil { return "", err } @@ -1506,7 +1506,7 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str } func setEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) + err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } @@ -1515,7 +1515,7 @@ func setEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInpu } func getEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - result, err := GetEvent[string](ctx, WorkflowGetEventInput{ + result, err := GetEvent[string](ctx, dbos, WorkflowGetEventInput{ TargetWorkflowID: input.Key, Key: input.Message, Timeout: 3 * time.Second, @@ -1532,11 +1532,11 @@ func TestSetGetEvent(t *testing.T) { executor := setupDBOS(t) // Register all set/get event workflows with executor - setEventWf := WithWorkflow(executor, setEventWorkflow) - getEventWf := WithWorkflow(executor, getEventWorkflow) - setTwoEventsWf := WithWorkflow(executor, setTwoEventsWorkflow) - setEventIdempotencyWf := WithWorkflow(executor, setEventIdempotencyWorkflow) - getEventIdempotencyWf := WithWorkflow(executor, getEventIdempotencyWorkflow) + setEventWf := RegisterWorkflow(executor, setEventWorkflow) + getEventWf := RegisterWorkflow(executor, getEventWorkflow) + setTwoEventsWf := RegisterWorkflow(executor, setTwoEventsWorkflow) + setEventIdempotencyWf := RegisterWorkflow(executor, setEventIdempotencyWorkflow) + getEventIdempotencyWf := RegisterWorkflow(executor, getEventIdempotencyWorkflow) t.Run("SetGetEventFromWorkflow", func(t *testing.T) { // Clear the signal event before starting @@ -1617,7 +1617,7 @@ func TestSetGetEvent(t *testing.T) { } // Start a workflow that gets the event from outside the original workflow - message, err := GetEvent[string](context.Background(), WorkflowGetEventInput{ + message, err := GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "test-key", Timeout: 3 * time.Second, @@ -1633,7 +1633,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventTimeout", func(t *testing.T) { // Try to get an event from a non-existent workflow nonExistentID := uuid.NewString() - message, err := GetEvent[string](context.Background(), WorkflowGetEventInput{ + message, err := GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ TargetWorkflowID: nonExistentID, Key: "test-key", Timeout: 3 * time.Second, @@ -1657,7 +1657,7 @@ func TestSetGetEvent(t *testing.T) { if err != nil { t.Fatal("failed to get result from set event workflow:", err) } - message, err = GetEvent[string](context.Background(), WorkflowGetEventInput{ + message, err = GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "non-existent-key", Timeout: 3 * time.Second, @@ -1674,7 +1674,7 @@ func TestSetGetEvent(t *testing.T) { ctx := context.Background() // Attempt to run SetEvent outside of a workflow context - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: "test-key", Message: "test-message"}) + err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: "test-key", Message: "test-message"}) if err == nil { t.Fatal("expected error when running SetEvent outside of workflow context, but got none") } @@ -1791,7 +1791,7 @@ func TestSetGetEvent(t *testing.T) { for range numGoroutines { go func() { defer wg.Done() - res, err := GetEvent[string](context.Background(), WorkflowGetEventInput{ + res, err := GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "concurrent-event-key", Timeout: 10 * time.Second, @@ -1822,7 +1822,7 @@ var ( ) func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Duration, error) { - result, err := Sleep(ctx, duration) + result, err := Sleep(ctx, dbos, duration) if err != nil { return 0, err } @@ -1834,7 +1834,7 @@ func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Du func TestSleep(t *testing.T) { executor := setupDBOS(t) - sleepRecoveryWf := WithWorkflow(executor, sleepRecoveryWorkflow) + sleepRecoveryWf := RegisterWorkflow(executor, sleepRecoveryWorkflow) t.Run("SleepDurableRecovery", func(t *testing.T) { sleepStartEvent = NewEvent() @@ -1891,7 +1891,7 @@ func TestSleep(t *testing.T) { ctx := context.Background() // Attempt to call Sleep outside of a workflow context - _, err := Sleep(ctx, 1*time.Second) + _, err := Sleep(ctx, dbos, 1*time.Second) if err == nil { t.Fatal("expected error when calling Sleep outside of workflow context, but got none") } From 82004427fb7583e8ba48c513da66fb3dc7963b22 Mon Sep 17 00:00:00 2001 From: maxdml Date: Wed, 30 Jul 2025 13:32:42 -0700 Subject: [PATCH 04/30] new interface --- dbos/dbos.go | 30 +++--- dbos/queues_test.go | 12 +-- dbos/recovery.go | 6 +- dbos/serialization_test.go | 4 +- dbos/system_database.go | 52 +++++------ dbos/workflow.go | 184 ++++++++++++++++++++++++------------- dbos/workflows_test.go | 24 ++--- 7 files changed, 184 insertions(+), 128 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index 2ffb8fa4..2d01b55e 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -69,11 +69,11 @@ type DBOSContext interface { WithValue(key, val any) DBOSContext // Workflow registration - RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) - RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) + RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) + RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) // Workflow operations - RunAsStep(fn TypeErasedStepFunc, input any, stepName string, opts ...stepOption) (any, error) + RunAsStep(fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) Send(input WorkflowSendInputInternal) error Recv(input WorkflowRecvInput) (any, error) SetEvent(input WorkflowSetEventInputInternal) error @@ -81,14 +81,19 @@ type DBOSContext interface { Sleep(duration time.Duration) (time.Duration, error) // Workflow management + InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*InsertWorkflowResult, error) + UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error RetrieveWorkflow(workflowIDs []string) ([]WorkflowStatus, error) + ListWorkflows(input ListWorkflowsDBInput) ([]WorkflowStatus, error) CheckChildWorkflow(parentWorkflowID string, stepCounter int) (*string, error) - InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*insertWorkflowResult, error) - RecordChildWorkflow(input recordChildWorkflowDBInput) error - UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error + RecordChildWorkflow(input RecordChildWorkflowDBInput) error + CheckOperationExecution(workflowID string, stepID int, stepName string) (*RecordedResult, error) + RecordOperationResult(input RecordOperationResultDBInput) error + RecordChildGetResult(input RecordChildGetResultDBInput) error // Context operations GetWorkflowID() (string, error) + AwaitWorkflowResult(workflowID string) (any, error) // Accessors GetWorkflowScheduler() *cron.Cron @@ -191,8 +196,10 @@ func (e *dbosContext) GetWorkflowWg() *sync.WaitGroup { func NewDBOSContext(inputConfig Config) (DBOSContext, error) { initExecutor := &dbosContext{ - workflowsWg: &sync.WaitGroup{}, - ctx: context.Background(), + workflowsWg: &sync.WaitGroup{}, + ctx: context.Background(), + workflowRegistry: make(map[string]workflowRegistryEntry), + workflowRegMutex: &sync.RWMutex{}, } // Load & process the configuration @@ -232,9 +239,6 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { initExecutor.systemDB = systemDB logger.Info("System database initialized") - // Initialize the workflow registry - initExecutor.workflowRegistry = make(map[string]workflowRegistryEntry) - return initExecutor, nil } @@ -255,8 +259,8 @@ func (e *dbosContext) Launch() error { } // Create context with cancel function for queue runner - // XXX this can now be a cancel function on the executor itself? - ctx, cancel := context.WithCancel(context.Background()) + // FIXME: cancellation now has to go through the DBOSContext + ctx, cancel := context.WithCancel(e.GetContext()) e.queueRunnerCtx = ctx e.queueRunnerCancelFunc = cancel e.queueRunnerDone = make(chan struct{}) diff --git a/dbos/queues_test.go b/dbos/queues_test.go index d77a8c19..682e537d 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -257,8 +257,8 @@ func TestQueueRecovery(t *testing.T) { executor := setupDBOS(t) // Create workflows with executor - var recoveryStepWorkflow func(context.Context, int, ...workflowOption) (WorkflowHandle[int], error) - var recoveryWorkflow func(context.Context, string, ...workflowOption) (WorkflowHandle[[]int], error) + var recoveryStepWorkflow func(context.Context, int, ...WorkflowOption) (WorkflowHandle[int], error) + var recoveryWorkflow func(context.Context, string, ...WorkflowOption) (WorkflowHandle[[]int], error) recoveryStepWorkflow = RegisterWorkflow(executor, func(ctx context.Context, i int) (int, error) { recoveryStepCounter++ @@ -505,7 +505,7 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[1].IsSet || startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 1 to start, but others have started") } - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -529,7 +529,7 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 2 to start, but others have started") } - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err = dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -561,7 +561,7 @@ func TestWorkerConcurrency(t *testing.T) { t.Fatal("expected only blocking workflow 3 to start, but workflow 4 has started") } // Check that only one workflow is pending - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err = dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -589,7 +589,7 @@ func TestWorkerConcurrency(t *testing.T) { restartQueueRunner() startEvents[3].Wait() // Check no workflow is enqueued - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err = dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) diff --git a/dbos/recovery.go b/dbos/recovery.go index a529311c..b6e543ce 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -7,7 +7,7 @@ import ( func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors - pendingWorkflows, err := dbosCtx.systemDB.ListWorkflows(dbosCtx.GetContext(), listWorkflowsDBInput{ + pendingWorkflows, err := dbosCtx.systemDB.ListWorkflows(dbosCtx.GetContext(), ListWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusPending}, executorIDs: executorIDs, applicationVersion: dbosCtx.applicationVersion, @@ -32,7 +32,7 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work continue } if cleared { - workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, systemDB: dbosCtx.systemDB}) + workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, dbosContext: dbosCtx}) } continue } @@ -44,7 +44,7 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work } // Convert workflow parameters to options - opts := []workflowOption{ + opts := []WorkflowOption{ WithWorkflowID(workflow.ID), } // XXX we'll figure out the exact timeout/deadline settings later diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 90332191..22900510 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -99,7 +99,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -216,7 +216,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { diff --git a/dbos/system_database.go b/dbos/system_database.go index 46f964d9..c2177bfb 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -29,19 +29,19 @@ type SystemDatabase interface { ResetSystemDB(ctx context.Context) error // Workflows - InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) - ListWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) + InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*InsertWorkflowResult, error) + ListWorkflows(ctx context.Context, input ListWorkflowsDBInput) ([]WorkflowStatus, error) UpdateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error AwaitWorkflowResult(ctx context.Context, workflowID string) (any, error) // Child workflows - RecordChildWorkflow(ctx context.Context, input recordChildWorkflowDBInput) error + RecordChildWorkflow(ctx context.Context, input RecordChildWorkflowDBInput) error CheckChildWorkflow(ctx context.Context, workflowUUID string, functionID int) (*string, error) - RecordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error + RecordChildGetResult(ctx context.Context, input RecordChildGetResultDBInput) error // Steps - RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error - CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) + RecordOperationResult(ctx context.Context, input RecordOperationResultDBInput) error + CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*RecordedResult, error) GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) // Communication (special steps) @@ -233,7 +233,7 @@ func (s *systemDatabase) Shutdown() { /******* WORKFLOWS ********/ /*******************************/ -type insertWorkflowResult struct { +type InsertWorkflowResult struct { attempts int status WorkflowStatusType name string @@ -247,7 +247,7 @@ type insertWorkflowStatusDBInput struct { maxRetries int } -func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) { +func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*InsertWorkflowResult, error) { tx, err := s.pool.Begin(ctx) if err != nil { return nil, fmt.Errorf("failed to begin transaction: %w", err) @@ -313,7 +313,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW END RETURNING recovery_attempts, status, name, queue_name, workflow_deadline_epoch_ms` - result := insertWorkflowResult{ + result := InsertWorkflowResult{ tx: tx, } err = tx.QueryRow(ctx, query, @@ -386,7 +386,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW } // ListWorkflowsInput represents the input parameters for listing workflows -type listWorkflowsDBInput struct { +type ListWorkflowsDBInput struct { workflowName string queueName string workflowIDPrefix string @@ -404,7 +404,7 @@ type listWorkflowsDBInput struct { } // ListWorkflows retrieves a list of workflows based on the provided filters -func (s *systemDatabase) ListWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) { +func (s *systemDatabase) ListWorkflows(ctx context.Context, input ListWorkflowsDBInput) ([]WorkflowStatus, error) { qb := newQueryBuilder() // Build the base query @@ -604,7 +604,7 @@ func (s *systemDatabase) CancelWorkflow(ctx context.Context, workflowID string) defer tx.Rollback(ctx) // Rollback if not committed // Check if workflow exists - listInput := listWorkflowsDBInput{ + listInput := ListWorkflowsDBInput{ workflowIDs: []string{workflowID}, tx: tx, } @@ -678,7 +678,7 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str } } -type recordOperationResultDBInput struct { +type RecordOperationResultDBInput struct { workflowID string stepID int stepName string @@ -687,7 +687,7 @@ type recordOperationResultDBInput struct { tx pgx.Tx } -func (s *systemDatabase) RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error { +func (s *systemDatabase) RecordOperationResult(ctx context.Context, input RecordOperationResultDBInput) error { query := `INSERT INTO dbos.operation_outputs (workflow_uuid, function_id, output, error, function_name) VALUES ($1, $2, $3, $4, $5) @@ -748,7 +748,7 @@ func (s *systemDatabase) RecordOperationResult(ctx context.Context, input record /******* CHILD WORKFLOWS ********/ /*******************************/ -type recordChildWorkflowDBInput struct { +type RecordChildWorkflowDBInput struct { parentWorkflowID string childWorkflowID string stepID int @@ -756,7 +756,7 @@ type recordChildWorkflowDBInput struct { tx pgx.Tx } -func (s *systemDatabase) RecordChildWorkflow(ctx context.Context, input recordChildWorkflowDBInput) error { +func (s *systemDatabase) RecordChildWorkflow(ctx context.Context, input RecordChildWorkflowDBInput) error { query := `INSERT INTO dbos.operation_outputs (workflow_uuid, function_id, function_name, child_workflow_id) VALUES ($1, $2, $3, $4)` @@ -814,7 +814,7 @@ func (s *systemDatabase) CheckChildWorkflow(ctx context.Context, workflowID stri return childWorkflowID, nil } -type recordChildGetResultDBInput struct { +type RecordChildGetResultDBInput struct { parentWorkflowID string childWorkflowID string stepID int @@ -822,7 +822,7 @@ type recordChildGetResultDBInput struct { err error } -func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error { +func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input RecordChildGetResultDBInput) error { query := `INSERT INTO dbos.operation_outputs (workflow_uuid, function_id, function_name, output, error, child_workflow_id) VALUES ($1, $2, $3, $4, $5, $6) @@ -852,7 +852,7 @@ func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input recordC /******* STEPS ********/ /*******************************/ -type recordedResult struct { +type RecordedResult struct { output any err error } @@ -864,7 +864,7 @@ type checkOperationExecutionDBInput struct { tx pgx.Tx } -func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) { +func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*RecordedResult, error) { var tx pgx.Tx var err error @@ -932,7 +932,7 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input chec if errorStr != nil && *errorStr != "" { recordedError = errors.New(*errorStr) } - result := &recordedResult{ + result := &RecordedResult{ output: output, err: recordedError, } @@ -1052,7 +1052,7 @@ func (s *systemDatabase) Sleep(ctx context.Context, duration time.Duration) (tim endTime = time.Now().Add(duration) // Record the operation result with the calculated end time - recordInput := recordOperationResultDBInput{ + recordInput := RecordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, @@ -1191,7 +1191,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInputIntern // Record the operation result if this is called within a workflow if isInWorkflow { - recordInput := recordOperationResultDBInput{ + recordInput := RecordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, @@ -1341,7 +1341,7 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any } // Record the operation result - recordInput := recordOperationResultDBInput{ + recordInput := RecordOperationResultDBInput{ workflowID: destinationID, stepID: stepID, stepName: functionName, @@ -1420,7 +1420,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp } // Record the operation result - recordInput := recordOperationResultDBInput{ + recordInput := RecordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, @@ -1543,7 +1543,7 @@ func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInp // Record the operation result if this is called within a workflow if isInWorkflow { - recordInput := recordOperationResultDBInput{ + recordInput := RecordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, diff --git a/dbos/workflow.go b/dbos/workflow.go index b4501021..abf0d9dd 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -77,7 +77,7 @@ type workflowOutcome[R any] struct { } type WorkflowHandle[R any] interface { - GetResult(ctx context.Context) (R, error) + GetResult() (R, error) GetStatus() (WorkflowStatus, error) GetWorkflowID() string } @@ -86,32 +86,32 @@ type WorkflowHandle[R any] interface { type workflowHandle[R any] struct { workflowID string outcomeChan chan workflowOutcome[R] - systemDB SystemDatabase + dbosContext DBOSContext } // GetResult waits for the workflow to complete and returns the result -func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { +func (h *workflowHandle[R]) GetResult() (R, error) { outcome, ok := <-h.outcomeChan // Blocking read if !ok { // Return an error if the channel was closed. In normal operations this would happen if GetResul() is called twice on a handler. The first call should get the buffered result, the second call find zero values (channel is empty and closed). return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?") } // If we are calling GetResult inside a workflow, record the result as a step result - parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) + parentWorkflowState, ok := h.dbosContext.GetContext().Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil if isChildWorkflow { encodedOutput, encErr := serialize(outcome.result) if encErr != nil { return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } - recordGetResultInput := recordChildGetResultDBInput{ + recordGetResultInput := RecordChildGetResultDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: h.workflowID, stepID: parentWorkflowState.NextStepID(), output: encodedOutput, err: outcome.err, } - recordResultErr := h.systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := h.dbosContext.RecordChildGetResult(recordGetResultInput) if recordResultErr != nil { getLogger().Error("failed to record get result", "error", recordResultErr) return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) @@ -122,8 +122,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowHandle[R]) GetStatus() (WorkflowStatus, error) { - ctx := context.Background() - workflowStatuses, err := h.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowStatuses, err := h.dbosContext.ListWorkflows(ListWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -140,12 +139,13 @@ func (h *workflowHandle[R]) GetWorkflowID() string { } type workflowPollingHandle[R any] struct { - workflowID string - systemDB SystemDatabase + workflowID string + dbosContext DBOSContext } -func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { - result, err := h.systemDB.AwaitWorkflowResult(ctx, h.workflowID) +func (h *workflowPollingHandle[R]) GetResult() (R, error) { + ctx := context.Background() + result, err := h.dbosContext.AwaitWorkflowResult(h.workflowID) if result != nil { typedResult, ok := result.(R) if !ok { @@ -160,14 +160,14 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { if encErr != nil { return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } - recordGetResultInput := recordChildGetResultDBInput{ + recordGetResultInput := RecordChildGetResultDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: h.workflowID, stepID: parentWorkflowState.NextStepID(), output: encodedOutput, err: err, } - recordResultErr := h.systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := h.dbosContext.RecordChildGetResult(recordGetResultInput) if recordResultErr != nil { // XXX do we want to fail this? getLogger().Error("failed to record get result", "error", recordResultErr) @@ -180,8 +180,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowPollingHandle[R]) GetStatus() (WorkflowStatus, error) { - ctx := context.Background() - workflowStatuses, err := h.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowStatuses, err := h.dbosContext.ListWorkflows(ListWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -200,15 +199,16 @@ func (h *workflowPollingHandle[R]) GetWorkflowID() string { /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ -type typedErasedWorkflowWrapperFunc func(ctx DBOSContext, input any, opts ...workflowOption) (WorkflowHandle[any], error) +type GenericWrappedWorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P, opts ...WorkflowOption) (WorkflowHandle[R], error) +type WrappedWorkflowFunc func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) type workflowRegistryEntry struct { - wrappedFunction typedErasedWorkflowWrapperFunc + wrappedFunction WrappedWorkflowFunc maxRetries int } // Register adds a workflow function to the registry (thread-safe, only once per name) -func (e *dbosContext) RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { +func (e *dbosContext) RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) { e.workflowRegMutex.Lock() defer e.workflowRegMutex.Unlock() @@ -223,7 +223,7 @@ func (e *dbosContext) RegisterWorkflow(fqn string, fn typedErasedWorkflowWrapper } } -func (e *dbosContext) RegisterScheduledWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, cronSchedule string, maxRetries int) { +func (e *dbosContext) RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) { e.GetWorkflowScheduler().Start() var entryID cron.EntryID entryID, err := e.GetWorkflowScheduler().AddFunc(cronSchedule, func() { @@ -271,12 +271,14 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } -// TODO split RegisterWorkflow and RegisterScheduledWorkflow -func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { +// RegisterWorkflow wraps the provided function as a durable workflow and registers it with the provided DBOSContext workflow registry +// If the workflow is a scheduled workflow (determined by the presence of a cron schedule), it will also register a cron job to execute it +// RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding +// The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions +// RegisterWorkflow returns the statically typed wrapped function. The DBOSContext registry holds the typed-erased version +func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) GenericWrappedWorkflowFunc[P, R] { if dbosCtx == nil { - // TODO: consider panic here - getLogger().Error("Provide a valid DBOS context") - return nil + panic("dbosCtx cannot be nil") } registrationParams := workflowRegistrationParams{ @@ -298,13 +300,13 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], gob.Register(r) // Wrap the function in a durable workflow - wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx DBOSContext, workflowInput P, opts ...workflowOption) (WorkflowHandle[R], error) { + wrappedFunction := GenericWrappedWorkflowFunc[P, R](func(ctx DBOSContext, workflowInput P, opts ...WorkflowOption) (WorkflowHandle[R], error) { opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) - return runAsWorkflow(ctx, fn, workflowInput, opts...) + return RunAsWorkflow(ctx, fn, workflowInput, opts...) }) // Register a type-erased version of the durable workflow for recovery - typeErasedWrapper := func(ctx DBOSContext, input any, opts ...workflowOption) (WorkflowHandle[any], error) { + typeErasedWrapper := func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) @@ -314,7 +316,7 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], if err != nil { return nil, err } - return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), systemDB: ctx.(*dbosContext).systemDB}, nil + return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), dbosContext: ctx}, nil } dbosCtx.RegisterWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) @@ -337,8 +339,7 @@ type contextKey string const workflowStateKey contextKey = "workflowState" -type WorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P) (R, error) -type WorkflowWrapperFunc[P any, R any] func(dbosCtx DBOSContext, input P, opts ...workflowOption) (WorkflowHandle[R], error) +type GenericWorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P) (R, error) type workflowParams struct { workflowID string @@ -349,45 +350,53 @@ type workflowParams struct { maxRetries int } -type workflowOption func(*workflowParams) +type WorkflowOption func(*workflowParams) -func WithWorkflowID(id string) workflowOption { +func WithWorkflowID(id string) WorkflowOption { return func(p *workflowParams) { p.workflowID = id } } -func WithTimeout(timeout time.Duration) workflowOption { +func WithTimeout(timeout time.Duration) WorkflowOption { return func(p *workflowParams) { p.timeout = timeout } } -func WithDeadline(deadline time.Time) workflowOption { +func WithDeadline(deadline time.Time) WorkflowOption { return func(p *workflowParams) { p.deadline = deadline } } -func WithQueue(queueName string) workflowOption { +func WithQueue(queueName string) WorkflowOption { return func(p *workflowParams) { p.queueName = queueName } } -func WithApplicationVersion(version string) workflowOption { +func WithApplicationVersion(version string) WorkflowOption { return func(p *workflowParams) { p.applicationVersion = version } } -func WithWorkflowMaxRetries(maxRetries int) workflowOption { +func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { return func(p *workflowParams) { p.maxRetries = maxRetries } } -func runAsWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { +// RunAsWorkflow executes the provided function as a durable workflow +// It handles all the features of durable execution for workflows: +// - Workflow ID generation, if needed +// - Child workflow management +// - Enqueuing if a queue name is specified +// - Consistent recording in DBOS system database, returning a handle if the workflow is a terminal status +// The wrapped function is ran into a goroutine and the result is sent to a channel. +// RunAsWorkflow returns a workflow handle that can be used to get the result or status of the workflow. +func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { // Apply options to build params params := workflowParams{ applicationVersion: dbosCtx.GetApplicationVersion(), @@ -422,7 +431,7 @@ func runAsWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], inp return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { - return &workflowPollingHandle[R]{workflowID: *childWorkflowID, systemDB: dbosCtx.(*dbosContext).systemDB}, nil + return &workflowPollingHandle[R]{workflowID: *childWorkflowID, dbosContext: dbosCtx}, nil } } @@ -460,14 +469,14 @@ func runAsWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], inp if err := insertStatusResult.tx.Commit(dbosCtx.GetContext()); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowPollingHandle[R]{workflowID: workflowStatus.ID, systemDB: dbosCtx.(*dbosContext).systemDB}, nil + return &workflowPollingHandle[R]{workflowID: workflowStatus.ID, dbosContext: dbosCtx}, nil } // Record child workflow relationship if this is a child workflow if isChildWorkflow { // Get the step ID that was used for generating the child workflow ID stepID := parentWorkflowState.stepCounter - childInput := recordChildWorkflowDBInput{ + childInput := RecordChildWorkflowDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: workflowStatus.ID, stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this @@ -489,7 +498,7 @@ func runAsWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], inp handle := &workflowHandle[R]{ workflowID: workflowStatus.ID, outcomeChan: outcomeChan, - systemDB: dbosCtx.(*dbosContext).systemDB, + dbosContext: dbosCtx, } // Create workflow state to track step execution @@ -530,9 +539,8 @@ func runAsWorkflow[P any, R any](dbosCtx DBOSContext, fn WorkflowFunc[P, R], inp /******* STEP FUNCTIONS *******/ /******************************/ -type StepFunc[P any, R any] func(ctx context.Context, input P) (R, error) - -type TypeErasedStepFunc func(ctx context.Context, input any) (any, error) +type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error) +type StepFunc func(ctx context.Context, input any) (any, error) type StepParams struct { MaxRetries int @@ -541,38 +549,38 @@ type StepParams struct { MaxInterval time.Duration } -// stepOption is a functional option for configuring step parameters -type stepOption func(*StepParams) +// StepOption is a functional option for configuring step parameters +type StepOption func(*StepParams) // WithStepMaxRetries sets the maximum number of retries for a step -func WithStepMaxRetries(maxRetries int) stepOption { +func WithStepMaxRetries(maxRetries int) StepOption { return func(p *StepParams) { p.MaxRetries = maxRetries } } // WithBackoffFactor sets the backoff factor for retries (multiplier for exponential backoff) -func WithBackoffFactor(backoffFactor float64) stepOption { +func WithBackoffFactor(backoffFactor float64) StepOption { return func(p *StepParams) { p.BackoffFactor = backoffFactor } } // WithBaseInterval sets the base delay for the first retry -func WithBaseInterval(baseInterval time.Duration) stepOption { +func WithBaseInterval(baseInterval time.Duration) StepOption { return func(p *StepParams) { p.BaseInterval = baseInterval } } // WithMaxInterval sets the maximum delay for retries -func WithMaxInterval(maxInterval time.Duration) stepOption { +func WithMaxInterval(maxInterval time.Duration) StepOption { return func(p *StepParams) { p.MaxInterval = maxInterval } } -func (e *dbosContext) RunAsStep(fn TypeErasedStepFunc, input any, stepName string, opts ...stepOption) (any, error) { +func (e *dbosContext) RunAsStep(fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) { if fn == nil { return nil, newStepExecutionError("", "", "step function cannot be nil") } @@ -603,11 +611,7 @@ func (e *dbosContext) RunAsStep(fn TypeErasedStepFunc, input any, stepName strin stepID := wfState.NextStepID() // Check the step is cancelled, has already completed, or is called with a different name - recordedOutput, err := e.systemDB.CheckOperationExecution(e.ctx, checkOperationExecutionDBInput{ - workflowID: wfState.workflowID, - stepID: stepID, - stepName: stepName, - }) + recordedOutput, err := e.CheckOperationExecution(wfState.workflowID, stepID, stepName) if err != nil { return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) } @@ -621,6 +625,8 @@ func (e *dbosContext) RunAsStep(fn TypeErasedStepFunc, input any, stepName strin stepCounter: wfState.stepCounter, isWithinStep: true, } + + // Spawn a child DBOSContext with the step state stepCtx := e.WithValue(workflowStateKey, &stepState) stepOutput, stepError := fn(stepCtx, input) @@ -668,14 +674,14 @@ func (e *dbosContext) RunAsStep(fn TypeErasedStepFunc, input any, stepName strin } // Record the final result - dbInput := recordOperationResultDBInput{ + dbInput := RecordOperationResultDBInput{ workflowID: wfState.workflowID, stepName: stepName, stepID: stepID, err: stepError, output: stepOutput, } - recErr := e.systemDB.RecordOperationResult(e.ctx, dbInput) + recErr := e.RecordOperationResult(dbInput) if recErr != nil { return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) } @@ -683,7 +689,10 @@ func (e *dbosContext) RunAsStep(fn TypeErasedStepFunc, input any, stepName strin return stepOutput, stepError } -func RunAsStep[P any, R any](dbosCtx DBOSContext, fn StepFunc[P, R], input P, opts ...stepOption) (R, error) { +func RunAsStep[P any, R any](dbosCtx DBOSContext, fn GenericStepFunc[P, R], input P, opts ...StepOption) (R, error) { + if dbosCtx == nil { + return *new(R), errors.New("dbosCtx cannot be nil") + } if fn == nil { return *new(R), newStepExecutionError("", "", "step function cannot be nil") } @@ -731,6 +740,9 @@ func (e *dbosContext) Send(input WorkflowSendInputInternal) error { // Send sends a message to another workflow. // Send automatically registers the type of R for gob encoding func Send[R any](dbosCtx DBOSContext, input WorkflowSendInput[R]) error { + if dbosCtx == nil { + return errors.New("dbosCtx cannot be nil") + } var typedMessage R gob.Register(typedMessage) return dbosCtx.Send(WorkflowSendInputInternal{ @@ -750,6 +762,9 @@ func (e *dbosContext) Recv(input WorkflowRecvInput) (any, error) { } func Recv[R any](dbosCtx DBOSContext, input WorkflowRecvInput) (R, error) { + if dbosCtx == nil { + return *new(R), errors.New("dbosCtx cannot be nil") + } msg, err := dbosCtx.Recv(input) if err != nil { return *new(R), err @@ -779,6 +794,9 @@ func (e *dbosContext) SetEvent(input WorkflowSetEventInputInternal) error { // The event is a key value pair // SetEvent automatically registers the type of R for gob encoding func SetEvent[R any](dbosCtx DBOSContext, input WorkflowSetEventInput[R]) error { + if dbosCtx == nil { + return errors.New("dbosCtx cannot be nil") + } var typedMessage R gob.Register(typedMessage) return dbosCtx.SetEvent(WorkflowSetEventInputInternal{ @@ -798,6 +816,9 @@ func (e *dbosContext) GetEvent(input WorkflowGetEventInput) (any, error) { } func GetEvent[R any](dbosCtx DBOSContext, input WorkflowGetEventInput) (R, error) { + if dbosCtx == nil { + return *new(R), errors.New("dbosCtx cannot be nil") + } value, err := dbosCtx.GetEvent(input) if err != nil { return *new(R), err @@ -818,6 +839,9 @@ func (e *dbosContext) Sleep(duration time.Duration) (time.Duration, error) { } func Sleep(dbosCtx DBOSContext, duration time.Duration) (time.Duration, error) { + if dbosCtx == nil { + return 0, errors.New("dbosCtx cannot be nil") + } return dbosCtx.Sleep(duration) } @@ -834,13 +858,20 @@ func (e *dbosContext) GetWorkflowID() (string, error) { return wfState.workflowID, nil } +func (e *dbosContext) AwaitWorkflowResult(workflowID string) (any, error) { + return e.systemDB.AwaitWorkflowResult(e.ctx, workflowID) +} + func (e *dbosContext) RetrieveWorkflow(workflowIDs []string) ([]WorkflowStatus, error) { - return e.systemDB.ListWorkflows(e.ctx, listWorkflowsDBInput{ + return e.systemDB.ListWorkflows(e.ctx, ListWorkflowsDBInput{ workflowIDs: workflowIDs, }) } func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { + if dbosCtx == nil { + return workflowPollingHandle[R]{}, errors.New("dbosCtx cannot be nil") + } workflowStatus, err := dbosCtx.RetrieveWorkflow([]string{workflowID}) if err != nil { return workflowPollingHandle[R]{}, fmt.Errorf("failed to retrieve workflow status: %w", err) @@ -848,30 +879,51 @@ func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPo if len(workflowStatus) == 0 { return workflowPollingHandle[R]{}, newNonExistentWorkflowError(workflowID) } - return workflowPollingHandle[R]{workflowID: workflowID, systemDB: dbosCtx.(*dbosContext).systemDB}, nil + return workflowPollingHandle[R]{workflowID: workflowID, dbosContext: dbosCtx}, nil } +// "private" workflow management (used within runAsWorkflow, runAsStep, etc.) + func (e *dbosContext) CheckChildWorkflow(parentWorkflowID string, stepCounter int) (*string, error) { return e.systemDB.CheckChildWorkflow(e.ctx, parentWorkflowID, stepCounter) } -func (e *dbosContext) InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*insertWorkflowResult, error) { +func (e *dbosContext) InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*InsertWorkflowResult, error) { return e.systemDB.InsertWorkflowStatus(e.ctx, insertWorkflowStatusDBInput{ status: status, maxRetries: maxRetries, }) } -func (e *dbosContext) RecordChildWorkflow(input recordChildWorkflowDBInput) error { +func (e *dbosContext) RecordChildWorkflow(input RecordChildWorkflowDBInput) error { return e.systemDB.RecordChildWorkflow(e.ctx, input) } +func (e *dbosContext) RecordOperationResult(input RecordOperationResultDBInput) error { + return e.systemDB.RecordOperationResult(e.ctx, input) +} + +func (e *dbosContext) CheckOperationExecution(workflowID string, stepID int, stepName string) (*RecordedResult, error) { + return e.systemDB.CheckOperationExecution(e.ctx, checkOperationExecutionDBInput{ + workflowID: workflowID, + stepID: stepID, + stepName: stepName, + }) +} + func (e *dbosContext) UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error { return e.systemDB.UpdateWorkflowOutcome(e.ctx, updateWorkflowOutcomeDBInput{ workflowID: workflowID, status: status, err: err, output: output, - tx: nil, // No explicit transaction for this interface method }) } + +func (e *dbosContext) RecordChildGetResult(input RecordChildGetResultDBInput) error { + return e.systemDB.RecordChildGetResult(e.ctx, input) +} + +func (e *dbosContext) ListWorkflows(input ListWorkflowsDBInput) ([]WorkflowStatus, error) { + return e.systemDB.ListWorkflows(e.ctx, input) +} diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 1fd2dcdc..b17c70f4 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -116,7 +116,7 @@ func TestWorkflowsWrapping(t *testing.T) { type testCase struct { name string - workflowFunc func(context.Context, string, ...workflowOption) (any, error) + workflowFunc func(context.Context, string, ...WorkflowOption) (any, error) input string expectedResult any expectError bool @@ -126,7 +126,7 @@ func TestWorkflowsWrapping(t *testing.T) { tests := []testCase{ { name: "SimpleWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWf(ctx, input, opts...) if err != nil { return nil, err @@ -148,7 +148,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowError", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWfError(ctx, input, opts...) if err != nil { return nil, err @@ -161,7 +161,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStep", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWfWithStep(ctx, input, opts...) if err != nil { return nil, err @@ -174,7 +174,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowStruct", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWfStruct(ctx, input, opts...) if err != nil { return nil, err @@ -187,7 +187,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "ValueReceiverWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWfValue(ctx, input, opts...) if err != nil { return nil, err @@ -200,7 +200,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "interfaceMethodWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWfIface(ctx, input, opts...) if err != nil { return nil, err @@ -213,7 +213,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "GenericWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { // For generic workflow, we need to convert string to int for testing handle, err := wfInt(ctx, "42", opts...) // FIXME for now this returns a string because sys db accepts this if err != nil { @@ -227,7 +227,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "ClosureWithCapturedState", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := wfClose(ctx, input, opts...) if err != nil { return nil, err @@ -240,7 +240,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "AnonymousClosure", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := anonymousWf(ctx, input, opts...) if err != nil { return nil, err @@ -253,7 +253,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStepError", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { handle, err := simpleWfWithStepError(ctx, input, opts...) if err != nil { return nil, err @@ -666,7 +666,7 @@ func TestWorkflowRecovery(t *testing.T) { } // Using ListWorkflows, retrieve the status of the workflow - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ workflowIDs: []string{handle1.GetWorkflowID()}, }) if err != nil { From 897ba1a6cc6fd6bc68b1981d61a26b3495fffa08 Mon Sep 17 00:00:00 2001 From: maxdml Date: Wed, 30 Jul 2025 18:13:34 -0700 Subject: [PATCH 05/30] cleanup interface + RunAsWorkflow --- dbos/dbos.go | 151 +++++++--------- dbos/queue.go | 9 +- dbos/queues_test.go | 2 +- dbos/recovery.go | 6 +- dbos/serialization_test.go | 6 +- dbos/system_database.go | 78 ++++---- dbos/workflow.go | 358 +++++++++++++++++++------------------ 7 files changed, 291 insertions(+), 319 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index 2d01b55e..31f3fbaf 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -66,43 +66,28 @@ type DBOSContext interface { // Context Lifecycle Launch() error Shutdown() - WithValue(key, val any) DBOSContext // Workflow registration RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) // Workflow operations - RunAsStep(fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) - Send(input WorkflowSendInputInternal) error - Recv(input WorkflowRecvInput) (any, error) - SetEvent(input WorkflowSetEventInputInternal) error - GetEvent(input WorkflowGetEventInput) (any, error) + RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) + RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) + Send(_ DBOSContext, input WorkflowSendInputInternal) error + Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) + SetEvent(_ DBOSContext, input WorkflowSetEventInput) error + GetEvent(_ DBOSContext, input WorkflowGetEventInput) (any, error) Sleep(duration time.Duration) (time.Duration, error) + GetWorkflowID() (string, error) // Workflow management - InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*InsertWorkflowResult, error) - UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error - RetrieveWorkflow(workflowIDs []string) ([]WorkflowStatus, error) - ListWorkflows(input ListWorkflowsDBInput) ([]WorkflowStatus, error) - CheckChildWorkflow(parentWorkflowID string, stepCounter int) (*string, error) - RecordChildWorkflow(input RecordChildWorkflowDBInput) error - CheckOperationExecution(workflowID string, stepID int, stepName string) (*RecordedResult, error) - RecordOperationResult(input RecordOperationResultDBInput) error - RecordChildGetResult(input RecordChildGetResultDBInput) error - - // Context operations - GetWorkflowID() (string, error) - AwaitWorkflowResult(workflowID string) (any, error) + RetrieveWorkflow(_ DBOSContext, workflowID string) (WorkflowHandle[any], error) // Accessors - GetWorkflowScheduler() *cron.Cron GetApplicationVersion() string - GetSystemDB() SystemDatabase - GetContext() context.Context GetExecutorID() string GetApplicationID() string - GetWorkflowWg() *sync.WaitGroup } type dbosContext struct { @@ -133,65 +118,53 @@ type dbosContext struct { } // Implement contex.Context interface methods -func (e *dbosContext) Deadline() (deadline time.Time, ok bool) { - return e.ctx.Deadline() +func (c *dbosContext) Deadline() (deadline time.Time, ok bool) { + return c.ctx.Deadline() } -func (e *dbosContext) Done() <-chan struct{} { - return e.ctx.Done() +func (c *dbosContext) Done() <-chan struct{} { + return c.ctx.Done() } -func (e *dbosContext) Err() error { - return e.ctx.Err() +func (c *dbosContext) Err() error { + return c.ctx.Err() } -func (e *dbosContext) Value(key any) any { - return e.ctx.Value(key) +func (c *dbosContext) Value(key any) any { + return c.ctx.Value(key) } // Create a new context // This is intended for workflow contexts and step contexts // Hence we only set the relevant fields -func (e *dbosContext) WithValue(key, val any) DBOSContext { +func (c *dbosContext) withValue(key, val any) DBOSContext { return &dbosContext{ - ctx: context.WithValue(e.ctx, key, val), - systemDB: e.systemDB, - applicationVersion: e.applicationVersion, - executorID: e.executorID, - applicationID: e.applicationID, - workflowsWg: e.workflowsWg, + ctx: context.WithValue(c.ctx, key, val), + systemDB: c.systemDB, + applicationVersion: c.applicationVersion, + executorID: c.executorID, + applicationID: c.applicationID, + workflowsWg: c.workflowsWg, } } -func (e *dbosContext) GetContext() context.Context { - return e.ctx -} - -func (e *dbosContext) GetWorkflowScheduler() *cron.Cron { - if e.workflowScheduler == nil { - e.workflowScheduler = cron.New(cron.WithSeconds()) +func (c *dbosContext) getWorkflowScheduler() *cron.Cron { + if c.workflowScheduler == nil { + c.workflowScheduler = cron.New(cron.WithSeconds()) } - return e.workflowScheduler -} - -func (e *dbosContext) GetApplicationVersion() string { - return e.applicationVersion -} - -func (e *dbosContext) GetSystemDB() SystemDatabase { - return e.systemDB + return c.workflowScheduler } -func (e *dbosContext) GetExecutorID() string { - return e.executorID +func (c *dbosContext) GetApplicationVersion() string { + return c.applicationVersion } -func (e *dbosContext) GetApplicationID() string { - return e.applicationID +func (c *dbosContext) GetExecutorID() string { + return c.executorID } -func (e *dbosContext) GetWorkflowWg() *sync.WaitGroup { - return e.workflowsWg +func (c *dbosContext) GetApplicationID() string { + return c.applicationID } func NewDBOSContext(inputConfig Config) (DBOSContext, error) { @@ -202,7 +175,7 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { workflowRegMutex: &sync.RWMutex{}, } - // Load & process the configuration + // Load and process the configuration config, err := processConfig(&inputConfig) if err != nil { return nil, newInitializationError(err.Error()) @@ -242,44 +215,44 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { return initExecutor, nil } -func (e *dbosContext) Launch() error { +func (c *dbosContext) Launch() error { // Start the system database - e.systemDB.Launch(context.Background()) + c.systemDB.Launch(context.Background()) // Start the admin server if configured - if e.config.AdminServer { - adminServer := newAdminServer(e, _DEFAULT_ADMIN_SERVER_PORT) + if c.config.AdminServer { + adminServer := newAdminServer(c, _DEFAULT_ADMIN_SERVER_PORT) err := adminServer.Start() if err != nil { logger.Error("Failed to start admin server", "error", err) return newInitializationError(fmt.Sprintf("failed to start admin server: %v", err)) } logger.Info("Admin server started", "port", _DEFAULT_ADMIN_SERVER_PORT) - e.adminServer = adminServer + c.adminServer = adminServer } // Create context with cancel function for queue runner // FIXME: cancellation now has to go through the DBOSContext - ctx, cancel := context.WithCancel(e.GetContext()) - e.queueRunnerCtx = ctx - e.queueRunnerCancelFunc = cancel - e.queueRunnerDone = make(chan struct{}) + ctx, cancel := context.WithCancel(c.ctx) + c.queueRunnerCtx = ctx + c.queueRunnerCancelFunc = cancel + c.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(e.queueRunnerDone) - queueRunner(e) + defer close(c.queueRunnerDone) + queueRunner(c) }() logger.Info("Queue runner started") // Start the workflow scheduler if it has been initialized - if e.workflowScheduler != nil { - e.workflowScheduler.Start() + if c.workflowScheduler != nil { + c.workflowScheduler.Start() logger.Info("Workflow scheduler started") } // Run a round of recovery on the local executor - recoveryHandles, err := recoverPendingWorkflows(e, []string{e.executorID}) + recoveryHandles, err := recoverPendingWorkflows(c, []string{c.executorID}) if err != nil { return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err)) } @@ -287,24 +260,24 @@ func (e *dbosContext) Launch() error { logger.Info("Recovered pending workflows", "count", len(recoveryHandles)) } - logger.Info("DBOS initialized", "app_version", e.applicationVersion, "executor_id", e.executorID) + logger.Info("DBOS initialized", "app_version", c.applicationVersion, "executor_id", c.executorID) return nil } -func (e *dbosContext) Shutdown() { +func (c *dbosContext) Shutdown() { // Wait for all workflows to finish - e.workflowsWg.Wait() + c.workflowsWg.Wait() // Cancel the context to stop the queue runner - if e.queueRunnerCancelFunc != nil { - e.queueRunnerCancelFunc() + if c.queueRunnerCancelFunc != nil { + c.queueRunnerCancelFunc() // Wait for queue runner to finish - <-e.queueRunnerDone + <-c.queueRunnerDone getLogger().Info("Queue runner stopped") } - if e.workflowScheduler != nil { - ctx := e.workflowScheduler.Stop() + if c.workflowScheduler != nil { + ctx := c.workflowScheduler.Stop() // Wait for all running jobs to complete with 5-second timeout timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -317,19 +290,19 @@ func (e *dbosContext) Shutdown() { } } - if e.systemDB != nil { - e.systemDB.Shutdown() - e.systemDB = nil + if c.systemDB != nil { + c.systemDB.Shutdown() + c.systemDB = nil } - if e.adminServer != nil { - err := e.adminServer.Shutdown() + if c.adminServer != nil { + err := c.adminServer.Shutdown() if err != nil { getLogger().Error("Failed to shutdown admin server", "error", err) } else { getLogger().Info("Admin server shutdown complete") } - e.adminServer = nil + c.adminServer = nil } if logger != nil { diff --git a/dbos/queue.go b/dbos/queue.go index 43a39f32..070cbd8a 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -2,7 +2,6 @@ package dbos import ( "bytes" - "context" "encoding/base64" "encoding/gob" "math" @@ -119,7 +118,7 @@ func queueRunner(executor *dbosContext) { for queueName, queue := range workflowQueueRegistry { getLogger().Debug("Processing queue", "queue_name", queueName) // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := executor.systemDB.DequeueWorkflows(executor.GetContext(), queue, executor.executorID, executor.applicationVersion) + dequeuedWorkflows, err := executor.systemDB.DequeueWorkflows(executor.ctx, queue, executor.executorID, executor.applicationVersion) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { @@ -162,9 +161,7 @@ func queueRunner(executor *dbosContext) { } } - // Create a workflow context from the executor context - workflowCtx := executor.WithValue(context.Background(), nil) - _, err := registeredWorkflow.wrappedFunction(workflowCtx, input, WithWorkflowID(workflow.id)) + _, err := registeredWorkflow.wrappedFunction(executor, input, WithWorkflowID(workflow.id)) if err != nil { getLogger().Error("Error running queued workflow", "error", err) } @@ -186,7 +183,7 @@ func queueRunner(executor *dbosContext) { // Sleep with jittered interval, but allow early exit on context cancellation select { - case <-executor.GetContext().Done(): + case <-executor.ctx.Done(): getLogger().Info("Queue runner stopping due to context cancellation") return case <-time.After(sleepDuration): diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 682e537d..9292904b 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -55,7 +55,7 @@ func TestWorkflowQueues(t *testing.T) { queueWf := RegisterWorkflow(executor, queueWorkflow) // Create workflow with child that can call the main workflow - queueWfWithChild := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { + queueWfWithChild := RegisterWorkflow[string, string](executor, func(ctx DBOSContext, input string) (string, error) { // Start a child workflow childHandle, err := queueWf(ctx, input+"-child") if err != nil { diff --git a/dbos/recovery.go b/dbos/recovery.go index b6e543ce..67f035e9 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -7,7 +7,7 @@ import ( func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors - pendingWorkflows, err := dbosCtx.systemDB.ListWorkflows(dbosCtx.GetContext(), ListWorkflowsDBInput{ + pendingWorkflows, err := dbosCtx.systemDB.ListWorkflows(dbosCtx.ctx, listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusPending}, executorIDs: executorIDs, applicationVersion: dbosCtx.applicationVersion, @@ -26,7 +26,7 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work // fmt.Println("Recovering workflow:", workflow.ID, "Name:", workflow.Name, "Input:", workflow.Input, "QueueName:", workflow.QueueName) if workflow.QueueName != "" { - cleared, err := dbosCtx.systemDB.ClearQueueAssignment(dbosCtx.GetContext(), workflow.ID) + cleared, err := dbosCtx.systemDB.ClearQueueAssignment(dbosCtx.ctx, workflow.ID) if err != nil { getLogger().Error("Error clearing queue assignment for workflow", "workflow_id", workflow.ID, "name", workflow.Name, "error", err) continue @@ -56,7 +56,7 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work } // Create a workflow context from the executor context - workflowCtx := dbosCtx.WithValue(dbosCtx.GetContext(), nil) + workflowCtx := dbosCtx.withValue(dbosCtx.ctx, nil) handle, err := registeredWorkflow.wrappedFunction(workflowCtx, workflow.Input, opts...) if err != nil { return nil, err diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 22900510..a171053c 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -99,7 +99,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -216,7 +216,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -313,7 +313,7 @@ func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, }, } - err := SetEvent(ctx, dbos, WorkflowSetEventInput[UserDefinedEventData]{Key: input, Message: eventData}) + err := SetEvent(ctx, dbos, WorkflowSetEventInputGeneric[UserDefinedEventData]{Key: input, Message: eventData}) if err != nil { return "", err } diff --git a/dbos/system_database.go b/dbos/system_database.go index c2177bfb..05f10c41 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -29,25 +29,25 @@ type SystemDatabase interface { ResetSystemDB(ctx context.Context) error // Workflows - InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*InsertWorkflowResult, error) - ListWorkflows(ctx context.Context, input ListWorkflowsDBInput) ([]WorkflowStatus, error) + InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) + ListWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) UpdateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error AwaitWorkflowResult(ctx context.Context, workflowID string) (any, error) // Child workflows - RecordChildWorkflow(ctx context.Context, input RecordChildWorkflowDBInput) error + RecordChildWorkflow(ctx context.Context, input recordChildWorkflowDBInput) error CheckChildWorkflow(ctx context.Context, workflowUUID string, functionID int) (*string, error) - RecordChildGetResult(ctx context.Context, input RecordChildGetResultDBInput) error + RecordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error // Steps - RecordOperationResult(ctx context.Context, input RecordOperationResultDBInput) error - CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*RecordedResult, error) - GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) + RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error + CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) + GetWorkflowSteps(ctx context.Context, workflowID string) ([]stepInfo, error) // Communication (special steps) Send(ctx context.Context, input WorkflowSendInputInternal) error Recv(ctx context.Context, input WorkflowRecvInput) (any, error) - SetEvent(ctx context.Context, input WorkflowSetEventInputInternal) error + SetEvent(ctx context.Context, input WorkflowSetEventInput) error GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) // Timers (special steps) @@ -233,25 +233,25 @@ func (s *systemDatabase) Shutdown() { /******* WORKFLOWS ********/ /*******************************/ -type InsertWorkflowResult struct { +type insertWorkflowResult struct { attempts int status WorkflowStatusType name string queueName *string workflowDeadlineEpochMs *int64 - tx pgx.Tx } type insertWorkflowStatusDBInput struct { status WorkflowStatus maxRetries int + tx pgx.Tx } -func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*InsertWorkflowResult, error) { - tx, err := s.pool.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to begin transaction: %w", err) +func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) { + if input.tx == nil { + return nil, errors.New("transaction is required for InsertWorkflowStatus") } + tx := input.tx // Set default values attempts := 1 @@ -313,9 +313,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW END RETURNING recovery_attempts, status, name, queue_name, workflow_deadline_epoch_ms` - result := InsertWorkflowResult{ - tx: tx, - } + result := insertWorkflowResult{} err = tx.QueryRow(ctx, query, input.status.ID, input.status.Status, @@ -386,7 +384,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW } // ListWorkflowsInput represents the input parameters for listing workflows -type ListWorkflowsDBInput struct { +type listWorkflowsDBInput struct { workflowName string queueName string workflowIDPrefix string @@ -404,7 +402,7 @@ type ListWorkflowsDBInput struct { } // ListWorkflows retrieves a list of workflows based on the provided filters -func (s *systemDatabase) ListWorkflows(ctx context.Context, input ListWorkflowsDBInput) ([]WorkflowStatus, error) { +func (s *systemDatabase) ListWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) { qb := newQueryBuilder() // Build the base query @@ -604,7 +602,7 @@ func (s *systemDatabase) CancelWorkflow(ctx context.Context, workflowID string) defer tx.Rollback(ctx) // Rollback if not committed // Check if workflow exists - listInput := ListWorkflowsDBInput{ + listInput := listWorkflowsDBInput{ workflowIDs: []string{workflowID}, tx: tx, } @@ -678,7 +676,7 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str } } -type RecordOperationResultDBInput struct { +type recordOperationResultDBInput struct { workflowID string stepID int stepName string @@ -687,7 +685,7 @@ type RecordOperationResultDBInput struct { tx pgx.Tx } -func (s *systemDatabase) RecordOperationResult(ctx context.Context, input RecordOperationResultDBInput) error { +func (s *systemDatabase) RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error { query := `INSERT INTO dbos.operation_outputs (workflow_uuid, function_id, output, error, function_name) VALUES ($1, $2, $3, $4, $5) @@ -748,7 +746,7 @@ func (s *systemDatabase) RecordOperationResult(ctx context.Context, input Record /******* CHILD WORKFLOWS ********/ /*******************************/ -type RecordChildWorkflowDBInput struct { +type recordChildWorkflowDBInput struct { parentWorkflowID string childWorkflowID string stepID int @@ -756,7 +754,7 @@ type RecordChildWorkflowDBInput struct { tx pgx.Tx } -func (s *systemDatabase) RecordChildWorkflow(ctx context.Context, input RecordChildWorkflowDBInput) error { +func (s *systemDatabase) RecordChildWorkflow(ctx context.Context, input recordChildWorkflowDBInput) error { query := `INSERT INTO dbos.operation_outputs (workflow_uuid, function_id, function_name, child_workflow_id) VALUES ($1, $2, $3, $4)` @@ -814,7 +812,7 @@ func (s *systemDatabase) CheckChildWorkflow(ctx context.Context, workflowID stri return childWorkflowID, nil } -type RecordChildGetResultDBInput struct { +type recordChildGetResultDBInput struct { parentWorkflowID string childWorkflowID string stepID int @@ -822,7 +820,7 @@ type RecordChildGetResultDBInput struct { err error } -func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input RecordChildGetResultDBInput) error { +func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error { query := `INSERT INTO dbos.operation_outputs (workflow_uuid, function_id, function_name, output, error, child_workflow_id) VALUES ($1, $2, $3, $4, $5, $6) @@ -852,7 +850,7 @@ func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input RecordC /******* STEPS ********/ /*******************************/ -type RecordedResult struct { +type recordedResult struct { output any err error } @@ -864,7 +862,7 @@ type checkOperationExecutionDBInput struct { tx pgx.Tx } -func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*RecordedResult, error) { +func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) { var tx pgx.Tx var err error @@ -932,14 +930,14 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input chec if errorStr != nil && *errorStr != "" { recordedError = errors.New(*errorStr) } - result := &RecordedResult{ + result := &recordedResult{ output: output, err: recordedError, } return result, nil } -type StepInfo struct { +type stepInfo struct { FunctionID int FunctionName string Output any @@ -947,7 +945,7 @@ type StepInfo struct { ChildWorkflowID string } -func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) { +func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string) ([]stepInfo, error) { query := `SELECT function_id, function_name, output, error, child_workflow_id FROM dbos.operation_outputs WHERE workflow_uuid = $1` @@ -958,9 +956,9 @@ func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string } defer rows.Close() - var steps []StepInfo + var steps []stepInfo for rows.Next() { - var step StepInfo + var step stepInfo var outputString *string var errorString *string var childWorkflowID *string @@ -1052,7 +1050,7 @@ func (s *systemDatabase) Sleep(ctx context.Context, duration time.Duration) (tim endTime = time.Now().Add(duration) // Record the operation result with the calculated end time - recordInput := RecordOperationResultDBInput{ + recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, @@ -1191,7 +1189,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInputIntern // Record the operation result if this is called within a workflow if isInWorkflow { - recordInput := RecordOperationResultDBInput{ + recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, @@ -1341,7 +1339,7 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any } // Record the operation result - recordInput := RecordOperationResultDBInput{ + recordInput := recordOperationResultDBInput{ workflowID: destinationID, stepID: stepID, stepName: functionName, @@ -1360,12 +1358,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any return message, nil } -type WorkflowSetEventInputInternal struct { +type WorkflowSetEventInput struct { Key string Message any } -func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInputInternal) error { +func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInput) error { functionName := "DBOS.setEvent" // Get workflow state from context @@ -1420,7 +1418,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp } // Record the operation result - recordInput := RecordOperationResultDBInput{ + recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, @@ -1543,7 +1541,7 @@ func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInp // Record the operation result if this is called within a workflow if isInWorkflow { - recordInput := RecordOperationResultDBInput{ + recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, diff --git a/dbos/workflow.go b/dbos/workflow.go index abf0d9dd..a207e736 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -82,6 +82,24 @@ type WorkflowHandle[R any] interface { GetWorkflowID() string } +type workflowHandleInternal struct { + workflowID string + workflowStatus WorkflowStatus +} + +// unimplemented +func (h *workflowHandleInternal) GetResult() (any, error) { + return nil, nil +} + +func (h *workflowHandleInternal) GetStatus() (WorkflowStatus, error) { + return h.workflowStatus, nil +} + +func (h *workflowHandleInternal) GetWorkflowID() string { + return h.workflowID +} + // workflowHandle is a concrete implementation of WorkflowHandle type workflowHandle[R any] struct { workflowID string @@ -97,21 +115,21 @@ func (h *workflowHandle[R]) GetResult() (R, error) { return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?") } // If we are calling GetResult inside a workflow, record the result as a step result - parentWorkflowState, ok := h.dbosContext.GetContext().Value(workflowStateKey).(*workflowState) + parentWorkflowState, ok := h.dbosContext.(*dbosContext).ctx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil if isChildWorkflow { encodedOutput, encErr := serialize(outcome.result) if encErr != nil { return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } - recordGetResultInput := RecordChildGetResultDBInput{ + recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: h.workflowID, stepID: parentWorkflowState.NextStepID(), output: encodedOutput, err: outcome.err, } - recordResultErr := h.dbosContext.RecordChildGetResult(recordGetResultInput) + recordResultErr := h.dbosContext.(*dbosContext).systemDB.RecordChildGetResult(h.dbosContext.(*dbosContext).ctx, recordGetResultInput) if recordResultErr != nil { getLogger().Error("failed to record get result", "error", recordResultErr) return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) @@ -122,7 +140,7 @@ func (h *workflowHandle[R]) GetResult() (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowHandle[R]) GetStatus() (WorkflowStatus, error) { - workflowStatuses, err := h.dbosContext.ListWorkflows(ListWorkflowsDBInput{ + workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.ListWorkflows(h.dbosContext.(*dbosContext).ctx, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -145,7 +163,7 @@ type workflowPollingHandle[R any] struct { func (h *workflowPollingHandle[R]) GetResult() (R, error) { ctx := context.Background() - result, err := h.dbosContext.AwaitWorkflowResult(h.workflowID) + result, err := h.dbosContext.(*dbosContext).systemDB.AwaitWorkflowResult(h.dbosContext.(*dbosContext).ctx, h.workflowID) if result != nil { typedResult, ok := result.(R) if !ok { @@ -160,14 +178,14 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) { if encErr != nil { return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } - recordGetResultInput := RecordChildGetResultDBInput{ + recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: h.workflowID, stepID: parentWorkflowState.NextStepID(), output: encodedOutput, err: err, } - recordResultErr := h.dbosContext.RecordChildGetResult(recordGetResultInput) + recordResultErr := h.dbosContext.(*dbosContext).systemDB.RecordChildGetResult(h.dbosContext.(*dbosContext).ctx, recordGetResultInput) if recordResultErr != nil { // XXX do we want to fail this? getLogger().Error("failed to record get result", "error", recordResultErr) @@ -180,7 +198,7 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowPollingHandle[R]) GetStatus() (WorkflowStatus, error) { - workflowStatuses, err := h.dbosContext.ListWorkflows(ListWorkflowsDBInput{ + workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.ListWorkflows(h.dbosContext.(*dbosContext).ctx, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -208,38 +226,38 @@ type workflowRegistryEntry struct { } // Register adds a workflow function to the registry (thread-safe, only once per name) -func (e *dbosContext) RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) { - e.workflowRegMutex.Lock() - defer e.workflowRegMutex.Unlock() +func (c *dbosContext) RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) { + c.workflowRegMutex.Lock() + defer c.workflowRegMutex.Unlock() - if _, exists := e.workflowRegistry[fqn]; exists { + if _, exists := c.workflowRegistry[fqn]; exists { getLogger().Error("workflow function already registered", "fqn", fqn) panic(newConflictingRegistrationError(fqn)) } - e.workflowRegistry[fqn] = workflowRegistryEntry{ + c.workflowRegistry[fqn] = workflowRegistryEntry{ wrappedFunction: fn, maxRetries: maxRetries, } } -func (e *dbosContext) RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) { - e.GetWorkflowScheduler().Start() +func (c *dbosContext) RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) { + c.getWorkflowScheduler().Start() var entryID cron.EntryID - entryID, err := e.GetWorkflowScheduler().AddFunc(cronSchedule, func() { + entryID, err := c.getWorkflowScheduler().AddFunc(cronSchedule, func() { // Execute the workflow on the cron schedule once DBOS is launched - if e == nil { + if c == nil { return } // Get the scheduled time from the cron entry - entry := e.GetWorkflowScheduler().Entry(entryID) + entry := c.getWorkflowScheduler().Entry(entryID) scheduledTime := entry.Prev if scheduledTime.IsZero() { // Use Next if Prev is not set, which will only happen for the first run scheduledTime = entry.Next } wfID := fmt.Sprintf("sched-%s-%s", fqn, scheduledTime) // XXX we can rethink the format - fn(e, scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) + fn(c, scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) }) if err != nil { panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) @@ -271,12 +289,12 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } -// RegisterWorkflow wraps the provided function as a durable workflow and registers it with the provided DBOSContext workflow registry +// RegisterWorkflow registers the provided function as a durable workflow with the provided DBOSContext workflow registry // If the workflow is a scheduled workflow (determined by the presence of a cron schedule), it will also register a cron job to execute it // RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding // The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions -// RegisterWorkflow returns the statically typed wrapped function. The DBOSContext registry holds the typed-erased version -func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) GenericWrappedWorkflowFunc[P, R] { +// To execute the workflow, use DBOSContext.RunAsWorkflow +func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { if dbosCtx == nil { panic("dbosCtx cannot be nil") } @@ -299,12 +317,6 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ gob.Register(p) gob.Register(r) - // Wrap the function in a durable workflow - wrappedFunction := GenericWrappedWorkflowFunc[P, R](func(ctx DBOSContext, workflowInput P, opts ...WorkflowOption) (WorkflowHandle[R], error) { - opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) - return RunAsWorkflow(ctx, fn, workflowInput, opts...) - }) - // Register a type-erased version of the durable workflow for recovery typeErasedWrapper := func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) @@ -312,7 +324,8 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } - handle, err := wrappedFunction(ctx, typedInput, opts...) + opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) + handle, err := RunAsWorkflow(ctx, fn, typedInput, opts...) if err != nil { return nil, err } @@ -327,8 +340,6 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ } dbosCtx.RegisterScheduledWorkflow(fqn, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) } - - return wrappedFunction } /**********************************/ @@ -340,6 +351,7 @@ type contextKey string const workflowStateKey contextKey = "workflowState" type GenericWorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P) (R, error) +type WorkflowFunc func(ctx DBOSContext, input any) (WorkflowHandle[any], error) type workflowParams struct { workflowID string @@ -388,15 +400,78 @@ func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { } } -// RunAsWorkflow executes the provided function as a durable workflow -// It handles all the features of durable execution for workflows: -// - Workflow ID generation, if needed -// - Child workflow management -// - Enqueuing if a queue name is specified -// - Consistent recording in DBOS system database, returning a handle if the workflow is a terminal status -// The wrapped function is ran into a goroutine and the result is sent to a channel. -// RunAsWorkflow returns a workflow handle that can be used to get the result or status of the workflow. func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { + // Do the durability things + typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (WorkflowHandle[any], error) { + // Dummy typed erased workflow -- we just need the name inside dbosCtx.RunAsWorkflow but want a matching signature + return nil, nil + }) + // Print fn name + fmt.Println("Running workflow function:", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) + // Print t wrapped function name + fmt.Println("Running typed-erased workflow function:", runtime.FuncForPC(reflect.ValueOf(typedErasedWorkflow).Pointer()).Name()) + internalHandle, err := dbosCtx.(*dbosContext).RunAsWorkflow(dbosCtx, typedErasedWorkflow, input, opts...) + if err != nil { + return nil, err + } + + // If we got a polling handle, return it directly + if pollingHandle, ok := internalHandle.(*workflowPollingHandle[any]); ok { + // We need to convert the polling handle to a typed handle + typedPollingHandle := &workflowPollingHandle[R]{ + workflowID: pollingHandle.workflowID, + dbosContext: dbosCtx, + } + return typedPollingHandle, nil + } + + // Channel to receive the outcome from the goroutine + // The buffer size of 1 allows the goroutine to send the outcome without blocking + // In addition it allows the channel to be garbage collected + outcomeChan := make(chan workflowOutcome[R], 1) + + // Create the handle + handle := &workflowHandle[R]{ + workflowID: internalHandle.GetWorkflowID(), + outcomeChan: outcomeChan, + dbosContext: dbosCtx, + } + + // Create workflow state to track step execution + wfState := &workflowState{ + workflowID: internalHandle.GetWorkflowID(), + stepCounter: -1, + } + + // Run the function in a goroutine + augmentUserContext := dbosCtx.(*dbosContext).withValue(workflowStateKey, wfState) + dbosCtx.(*dbosContext).workflowsWg.Add(1) + go func() { + defer dbosCtx.(*dbosContext).workflowsWg.Done() + result, err := fn(augmentUserContext, input) + status := WorkflowStatusSuccess + if err != nil { + status = WorkflowStatusError + } + recordErr := dbosCtx.(*dbosContext).systemDB.UpdateWorkflowOutcome(dbosCtx.(*dbosContext).ctx, updateWorkflowOutcomeDBInput{ + workflowID: internalHandle.GetWorkflowID(), + status: status, + err: err, + output: result, + }) + if recordErr != nil { + outcomeChan <- workflowOutcome[R]{result: *new(R), err: recordErr} + close(outcomeChan) // Close the channel to signal completion + return + } + outcomeChan <- workflowOutcome[R]{result: result, err: err} + close(outcomeChan) // Close the channel to signal completion + }() + + return handle, nil +} + +func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { // Apply options to build params params := workflowParams{ applicationVersion: dbosCtx.GetApplicationVersion(), @@ -426,12 +501,12 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := dbosCtx.CheckChildWorkflow(parentWorkflowState.workflowID, parentWorkflowState.stepCounter) + childWorkflowID, err := dbosCtx.(*dbosContext).systemDB.CheckChildWorkflow(dbosCtx.(*dbosContext).ctx, parentWorkflowState.workflowID, parentWorkflowState.stepCounter) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { - return &workflowPollingHandle[R]{workflowID: *childWorkflowID, dbosContext: dbosCtx}, nil + return &workflowPollingHandle[any]{workflowID: *childWorkflowID, dbosContext: dbosCtx}, nil } } @@ -457,82 +532,55 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, } // Init status and record child workflow relationship in a single transaction - insertStatusResult, err := dbosCtx.InsertWorkflowStatus(workflowStatus, params.maxRetries) + tx, err := dbosCtx.(*dbosContext).systemDB.(*systemDatabase).pool.Begin(dbosCtx.(*dbosContext).ctx) + if err != nil { + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) + } + defer tx.Rollback(dbosCtx.(*dbosContext).ctx) // Rollback if not committed + + // Insert workflow status with transaction + insertInput := insertWorkflowStatusDBInput{ + status: workflowStatus, + maxRetries: params.maxRetries, + tx: tx, + } + insertStatusResult, err := dbosCtx.(*dbosContext).systemDB.InsertWorkflowStatus(dbosCtx.(*dbosContext).ctx, insertInput) if err != nil { return nil, err } - defer insertStatusResult.tx.Rollback(dbosCtx.GetContext()) // Rollback if not committed // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { // Commit the transaction to update the number of attempts and/or enact the enqueue - if err := insertStatusResult.tx.Commit(dbosCtx.GetContext()); err != nil { + if err := tx.Commit(dbosCtx.(*dbosContext).ctx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowPollingHandle[R]{workflowID: workflowStatus.ID, dbosContext: dbosCtx}, nil + return &workflowPollingHandle[any]{workflowID: workflowStatus.ID, dbosContext: dbosCtx}, nil } // Record child workflow relationship if this is a child workflow if isChildWorkflow { // Get the step ID that was used for generating the child workflow ID stepID := parentWorkflowState.stepCounter - childInput := RecordChildWorkflowDBInput{ + childInput := recordChildWorkflowDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: workflowStatus.ID, stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this stepID: stepID, - tx: insertStatusResult.tx, + tx: tx, } - err = dbosCtx.RecordChildWorkflow(childInput) + err = dbosCtx.(*dbosContext).systemDB.RecordChildWorkflow(dbosCtx.(*dbosContext).ctx, childInput) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow: %v", err)) } } - // Channel to receive the outcome from the goroutine - // The buffer size of 1 allows the goroutine to send the outcome without blocking - // In addition it allows the channel to be garbage collected - outcomeChan := make(chan workflowOutcome[R], 1) - - // Create the handle - handle := &workflowHandle[R]{ - workflowID: workflowStatus.ID, - outcomeChan: outcomeChan, - dbosContext: dbosCtx, - } - - // Create workflow state to track step execution - wfState := &workflowState{ - workflowID: workflowStatus.ID, - stepCounter: -1, - } - - // Run the function in a goroutine - augmentUserContext := dbosCtx.WithValue(workflowStateKey, wfState) - dbosCtx.GetWorkflowWg().Add(1) - go func() { - defer dbosCtx.GetWorkflowWg().Done() - result, err := fn(augmentUserContext, input) - status := WorkflowStatusSuccess - if err != nil { - status = WorkflowStatusError - } - recordErr := dbosCtx.UpdateWorkflowOutcome(workflowStatus.ID, status, err, result) - if recordErr != nil { - outcomeChan <- workflowOutcome[R]{result: *new(R), err: recordErr} - close(outcomeChan) // Close the channel to signal completion - return - } - outcomeChan <- workflowOutcome[R]{result: result, err: err} - close(outcomeChan) // Close the channel to signal completion - }() - // Commit the transaction - if err := insertStatusResult.tx.Commit(dbosCtx.GetContext()); err != nil { + if err := tx.Commit(dbosCtx.(*dbosContext).ctx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return handle, nil + return &workflowHandleInternal{workflowID: workflowID}, nil } /******************************/ @@ -580,7 +628,7 @@ func WithMaxInterval(maxInterval time.Duration) StepOption { } } -func (e *dbosContext) RunAsStep(fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) { +func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) { if fn == nil { return nil, newStepExecutionError("", "", "step function cannot be nil") } @@ -597,21 +645,25 @@ func (e *dbosContext) RunAsStep(fn StepFunc, input any, stepName string, opts .. } // Get workflow state from context - wfState, ok := e.ctx.Value(workflowStateKey).(*workflowState) + wfState, ok := c.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { return nil, newStepExecutionError("", stepName, "workflow state not found in context: are you running this step within a workflow?") } // If within a step, just run the function directly if wfState.isWithinStep { - return fn(e.ctx, input) + return fn(c.ctx, input) } // Get next step ID stepID := wfState.NextStepID() // Check the step is cancelled, has already completed, or is called with a different name - recordedOutput, err := e.CheckOperationExecution(wfState.workflowID, stepID, stepName) + recordedOutput, err := c.systemDB.CheckOperationExecution(c.ctx, checkOperationExecutionDBInput{ + workflowID: wfState.workflowID, + stepID: stepID, + stepName: stepName, + }) if err != nil { return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) } @@ -627,7 +679,7 @@ func (e *dbosContext) RunAsStep(fn StepFunc, input any, stepName string, opts .. } // Spawn a child DBOSContext with the step state - stepCtx := e.WithValue(workflowStateKey, &stepState) + stepCtx := c.withValue(workflowStateKey, &stepState) stepOutput, stepError := fn(stepCtx, input) @@ -648,8 +700,8 @@ func (e *dbosContext) RunAsStep(fn StepFunc, input any, stepName string, opts .. // Wait before retry select { - case <-e.ctx.Done(): - return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", e.ctx.Err())) + case <-c.ctx.Done(): + return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", c.ctx.Err())) case <-time.After(delay): // Continue to retry } @@ -674,14 +726,14 @@ func (e *dbosContext) RunAsStep(fn StepFunc, input any, stepName string, opts .. } // Record the final result - dbInput := RecordOperationResultDBInput{ + dbInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepName: stepName, stepID: stepID, err: stepError, output: stepOutput, } - recErr := e.RecordOperationResult(dbInput) + recErr := c.systemDB.RecordOperationResult(c.ctx, dbInput) if recErr != nil { return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) } @@ -709,7 +761,7 @@ func RunAsStep[P any, R any](dbosCtx DBOSContext, fn GenericStepFunc[P, R], inpu } // Call the executor method - result, err := dbosCtx.RunAsStep(typeErasedFn, input, stepName, opts...) + result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input, stepName, opts...) if err != nil { return *new(R), err } @@ -733,8 +785,8 @@ type WorkflowSendInput[R any] struct { Topic string } -func (e *dbosContext) Send(input WorkflowSendInputInternal) error { - return e.systemDB.Send(e.ctx, input) +func (c *dbosContext) Send(_ DBOSContext, input WorkflowSendInputInternal) error { + return c.systemDB.Send(c.ctx, input) } // Send sends a message to another workflow. @@ -745,7 +797,7 @@ func Send[R any](dbosCtx DBOSContext, input WorkflowSendInput[R]) error { } var typedMessage R gob.Register(typedMessage) - return dbosCtx.Send(WorkflowSendInputInternal{ + return dbosCtx.Send(dbosCtx, WorkflowSendInputInternal{ DestinationID: input.DestinationID, Message: input.Message, Topic: input.Topic, @@ -757,15 +809,15 @@ type WorkflowRecvInput struct { Timeout time.Duration } -func (e *dbosContext) Recv(input WorkflowRecvInput) (any, error) { - return e.systemDB.Recv(e.ctx, input) +func (c *dbosContext) Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) { + return c.systemDB.Recv(c.ctx, input) } func Recv[R any](dbosCtx DBOSContext, input WorkflowRecvInput) (R, error) { if dbosCtx == nil { return *new(R), errors.New("dbosCtx cannot be nil") } - msg, err := dbosCtx.Recv(input) + msg, err := dbosCtx.Recv(dbosCtx, input) if err != nil { return *new(R), err } @@ -781,25 +833,25 @@ func Recv[R any](dbosCtx DBOSContext, input WorkflowRecvInput) (R, error) { return typedMessage, nil } -type WorkflowSetEventInput[R any] struct { +type WorkflowSetEventInputGeneric[R any] struct { Key string Message R } -func (e *dbosContext) SetEvent(input WorkflowSetEventInputInternal) error { - return e.systemDB.SetEvent(e.ctx, input) +func (c *dbosContext) SetEvent(_ DBOSContext, input WorkflowSetEventInput) error { + return c.systemDB.SetEvent(c.ctx, input) } // Sets an event from a workflow. // The event is a key value pair // SetEvent automatically registers the type of R for gob encoding -func SetEvent[R any](dbosCtx DBOSContext, input WorkflowSetEventInput[R]) error { +func SetEvent[R any](dbosCtx DBOSContext, input WorkflowSetEventInputGeneric[R]) error { if dbosCtx == nil { return errors.New("dbosCtx cannot be nil") } var typedMessage R gob.Register(typedMessage) - return dbosCtx.SetEvent(WorkflowSetEventInputInternal{ + return dbosCtx.SetEvent(dbosCtx, WorkflowSetEventInput{ Key: input.Key, Message: input.Message, }) @@ -811,15 +863,15 @@ type WorkflowGetEventInput struct { Timeout time.Duration } -func (e *dbosContext) GetEvent(input WorkflowGetEventInput) (any, error) { - return e.systemDB.GetEvent(e.ctx, input) +func (c *dbosContext) GetEvent(_ DBOSContext, input WorkflowGetEventInput) (any, error) { + return c.systemDB.GetEvent(c.ctx, input) } func GetEvent[R any](dbosCtx DBOSContext, input WorkflowGetEventInput) (R, error) { if dbosCtx == nil { return *new(R), errors.New("dbosCtx cannot be nil") } - value, err := dbosCtx.GetEvent(input) + value, err := dbosCtx.GetEvent(dbosCtx, input) if err != nil { return *new(R), err } @@ -834,15 +886,8 @@ func GetEvent[R any](dbosCtx DBOSContext, input WorkflowGetEventInput) (R, error return typedValue, nil } -func (e *dbosContext) Sleep(duration time.Duration) (time.Duration, error) { - return e.systemDB.Sleep(e.ctx, duration) -} - -func Sleep(dbosCtx DBOSContext, duration time.Duration) (time.Duration, error) { - if dbosCtx == nil { - return 0, errors.New("dbosCtx cannot be nil") - } - return dbosCtx.Sleep(duration) +func (c *dbosContext) Sleep(duration time.Duration) (time.Duration, error) { + return c.systemDB.Sleep(c.ctx, duration) } /***********************************/ @@ -850,29 +895,34 @@ func Sleep(dbosCtx DBOSContext, duration time.Duration) (time.Duration, error) { /***********************************/ // GetWorkflowID retrieves the workflow ID from the context if called within a DBOS workflow -func (e *dbosContext) GetWorkflowID() (string, error) { - wfState, ok := e.ctx.Value(workflowStateKey).(*workflowState) +func (c *dbosContext) GetWorkflowID() (string, error) { + wfState, ok := c.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { return "", errors.New("not within a DBOS workflow context") } return wfState.workflowID, nil } -func (e *dbosContext) AwaitWorkflowResult(workflowID string) (any, error) { - return e.systemDB.AwaitWorkflowResult(e.ctx, workflowID) -} - -func (e *dbosContext) RetrieveWorkflow(workflowIDs []string) ([]WorkflowStatus, error) { - return e.systemDB.ListWorkflows(e.ctx, ListWorkflowsDBInput{ - workflowIDs: workflowIDs, +func (c *dbosContext) RetrieveWorkflow(_ DBOSContext, workflowID string) (WorkflowHandle[any], error) { + workflowStatus, err := c.systemDB.ListWorkflows(c.ctx, listWorkflowsDBInput{ + workflowIDs: []string{workflowID}, }) + if err != nil { + return nil, fmt.Errorf("failed to retrieve workflow status: %w", err) + } + if len(workflowStatus) == 0 { + return nil, newNonExistentWorkflowError(workflowID) + } + return &workflowPollingHandle[any]{workflowID: workflowID, dbosContext: c}, nil } func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { if dbosCtx == nil { return workflowPollingHandle[R]{}, errors.New("dbosCtx cannot be nil") } - workflowStatus, err := dbosCtx.RetrieveWorkflow([]string{workflowID}) + workflowStatus, err := dbosCtx.(*dbosContext).systemDB.ListWorkflows(dbosCtx.(*dbosContext).ctx, listWorkflowsDBInput{ + workflowIDs: []string{workflowID}, + }) if err != nil { return workflowPollingHandle[R]{}, fmt.Errorf("failed to retrieve workflow status: %w", err) } @@ -881,49 +931,3 @@ func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPo } return workflowPollingHandle[R]{workflowID: workflowID, dbosContext: dbosCtx}, nil } - -// "private" workflow management (used within runAsWorkflow, runAsStep, etc.) - -func (e *dbosContext) CheckChildWorkflow(parentWorkflowID string, stepCounter int) (*string, error) { - return e.systemDB.CheckChildWorkflow(e.ctx, parentWorkflowID, stepCounter) -} - -func (e *dbosContext) InsertWorkflowStatus(status WorkflowStatus, maxRetries int) (*InsertWorkflowResult, error) { - return e.systemDB.InsertWorkflowStatus(e.ctx, insertWorkflowStatusDBInput{ - status: status, - maxRetries: maxRetries, - }) -} - -func (e *dbosContext) RecordChildWorkflow(input RecordChildWorkflowDBInput) error { - return e.systemDB.RecordChildWorkflow(e.ctx, input) -} - -func (e *dbosContext) RecordOperationResult(input RecordOperationResultDBInput) error { - return e.systemDB.RecordOperationResult(e.ctx, input) -} - -func (e *dbosContext) CheckOperationExecution(workflowID string, stepID int, stepName string) (*RecordedResult, error) { - return e.systemDB.CheckOperationExecution(e.ctx, checkOperationExecutionDBInput{ - workflowID: workflowID, - stepID: stepID, - stepName: stepName, - }) -} - -func (e *dbosContext) UpdateWorkflowOutcome(workflowID string, status WorkflowStatusType, err error, output any) error { - return e.systemDB.UpdateWorkflowOutcome(e.ctx, updateWorkflowOutcomeDBInput{ - workflowID: workflowID, - status: status, - err: err, - output: output, - }) -} - -func (e *dbosContext) RecordChildGetResult(input RecordChildGetResultDBInput) error { - return e.systemDB.RecordChildGetResult(e.ctx, input) -} - -func (e *dbosContext) ListWorkflows(input ListWorkflowsDBInput) ([]WorkflowStatus, error) { - return e.systemDB.ListWorkflows(e.ctx, input) -} From e1eb06388c63001bd9db08014bd580254b442b03 Mon Sep 17 00:00:00 2001 From: maxdml Date: Thu, 31 Jul 2025 16:07:50 -0700 Subject: [PATCH 06/30] fix RunAsWorkflow and work on new step interface --- dbos/dbos.go | 28 ++-- dbos/recovery.go | 3 +- dbos/workflow.go | 359 ++++++++++++++++++++++++----------------------- 3 files changed, 204 insertions(+), 186 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index 31f3fbaf..e9ed8e00 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -72,7 +72,7 @@ type DBOSContext interface { RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) // Workflow operations - RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) + RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, error) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) Send(_ DBOSContext, input WorkflowSendInputInternal) error Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) @@ -91,7 +91,8 @@ type DBOSContext interface { } type dbosContext struct { - ctx context.Context // Embedded context for standard behavior + ctx context.Context + systemDB SystemDatabase adminServer *adminServer config *Config @@ -137,15 +138,22 @@ func (c *dbosContext) Value(key any) any { // Create a new context // This is intended for workflow contexts and step contexts // Hence we only set the relevant fields -func (c *dbosContext) withValue(key, val any) DBOSContext { - return &dbosContext{ - ctx: context.WithValue(c.ctx, key, val), - systemDB: c.systemDB, - applicationVersion: c.applicationVersion, - executorID: c.executorID, - applicationID: c.applicationID, - workflowsWg: c.workflowsWg, +func WithValue(ctx DBOSContext, key, val any) DBOSContext { + if ctx == nil { + return nil + } + // Will do nothing if the concrete type is not dbosContext + if dbosCtx, ok := ctx.(*dbosContext); ok { + return &dbosContext{ + ctx: context.WithValue(dbosCtx.ctx, key, val), + systemDB: dbosCtx.systemDB, + workflowsWg: dbosCtx.workflowsWg, + applicationVersion: dbosCtx.applicationVersion, + executorID: dbosCtx.executorID, + applicationID: dbosCtx.applicationID, + } } + return nil } func (c *dbosContext) getWorkflowScheduler() *cron.Cron { diff --git a/dbos/recovery.go b/dbos/recovery.go index 67f035e9..e5497d6e 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -56,8 +56,7 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work } // Create a workflow context from the executor context - workflowCtx := dbosCtx.withValue(dbosCtx.ctx, nil) - handle, err := registeredWorkflow.wrappedFunction(workflowCtx, workflow.Input, opts...) + handle, err := registeredWorkflow.wrappedFunction(dbosCtx, workflow.Input, opts...) if err != nil { return nil, err } diff --git a/dbos/workflow.go b/dbos/workflow.go index a207e736..f2596ebd 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -56,14 +56,14 @@ type WorkflowStatus struct { // workflowState holds the runtime state for a workflow execution type workflowState struct { workflowID string - stepCounter int + stepID int isWithinStep bool } // NextStepID returns the next step ID and increments the counter func (ws *workflowState) NextStepID() int { - ws.stepCounter++ - return ws.stepCounter + ws.stepID++ + return ws.stepID } /********************************/ @@ -82,24 +82,6 @@ type WorkflowHandle[R any] interface { GetWorkflowID() string } -type workflowHandleInternal struct { - workflowID string - workflowStatus WorkflowStatus -} - -// unimplemented -func (h *workflowHandleInternal) GetResult() (any, error) { - return nil, nil -} - -func (h *workflowHandleInternal) GetStatus() (WorkflowStatus, error) { - return h.workflowStatus, nil -} - -func (h *workflowHandleInternal) GetWorkflowID() string { - return h.workflowID -} - // workflowHandle is a concrete implementation of WorkflowHandle type workflowHandle[R any] struct { workflowID string @@ -346,14 +328,15 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ /******* WORKFLOW FUNCTIONS *******/ /**********************************/ -type contextKey string +type DBOSContextKey string -const workflowStateKey contextKey = "workflowState" +const workflowStateKey DBOSContextKey = "workflowState" type GenericWorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P) (R, error) -type WorkflowFunc func(ctx DBOSContext, input any) (WorkflowHandle[any], error) +type WorkflowFunc func(ctx DBOSContext, input any) (any, error) type workflowParams struct { + workflowName string workflowID string timeout time.Duration deadline time.Time @@ -400,23 +383,29 @@ func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { } } +func WithWorkflowName(name string) WorkflowOption { + return func(p *workflowParams) { + if len(p.workflowName) == 0 { + p.workflowName = name + } + } +} + func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { - // Do the durability things - typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (WorkflowHandle[any], error) { - // Dummy typed erased workflow -- we just need the name inside dbosCtx.RunAsWorkflow but want a matching signature - return nil, nil + // Set the workflow name in the options -- will not be applied if the user provided a name + opts = append(opts, WithWorkflowName(runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name())) + + typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { + return fn(ctx, input.(P)) }) - // Print fn name - fmt.Println("Running workflow function:", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) - // Print t wrapped function name - fmt.Println("Running typed-erased workflow function:", runtime.FuncForPC(reflect.ValueOf(typedErasedWorkflow).Pointer()).Name()) - internalHandle, err := dbosCtx.(*dbosContext).RunAsWorkflow(dbosCtx, typedErasedWorkflow, input, opts...) + + handle, err := dbosCtx.(*dbosContext).RunAsWorkflow(dbosCtx, typedErasedWorkflow, input, opts...) if err != nil { return nil, err } - // If we got a polling handle, return it directly - if pollingHandle, ok := internalHandle.(*workflowPollingHandle[any]); ok { + // If we got a polling handle, return its typed version + if pollingHandle, ok := handle.(*workflowPollingHandle[any]); ok { // We need to convert the polling handle to a typed handle typedPollingHandle := &workflowPollingHandle[R]{ workflowID: pollingHandle.workflowID, @@ -425,50 +414,41 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, return typedPollingHandle, nil } - // Channel to receive the outcome from the goroutine - // The buffer size of 1 allows the goroutine to send the outcome without blocking - // In addition it allows the channel to be garbage collected - outcomeChan := make(chan workflowOutcome[R], 1) + // Create a typed channel for the user to get a typed handle + if handle, ok := handle.(*workflowHandle[any]); ok { + typedOutcomeChan := make(chan workflowOutcome[R], 1) - // Create the handle - handle := &workflowHandle[R]{ - workflowID: internalHandle.GetWorkflowID(), - outcomeChan: outcomeChan, - dbosContext: dbosCtx, - } + go func() { + defer close(typedOutcomeChan) + outcome := <-handle.outcomeChan - // Create workflow state to track step execution - wfState := &workflowState{ - workflowID: internalHandle.GetWorkflowID(), - stepCounter: -1, - } + resultErr := outcome.err + var typedResult R + if typedRes, ok := outcome.result.(R); ok { + typedResult = typedRes + } else { // This should never happen + typedResult = *new(R) + typeErr := fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), outcome.result) + resultErr = errors.Join(resultErr, typeErr) + } - // Run the function in a goroutine - augmentUserContext := dbosCtx.(*dbosContext).withValue(workflowStateKey, wfState) - dbosCtx.(*dbosContext).workflowsWg.Add(1) - go func() { - defer dbosCtx.(*dbosContext).workflowsWg.Done() - result, err := fn(augmentUserContext, input) - status := WorkflowStatusSuccess - if err != nil { - status = WorkflowStatusError - } - recordErr := dbosCtx.(*dbosContext).systemDB.UpdateWorkflowOutcome(dbosCtx.(*dbosContext).ctx, updateWorkflowOutcomeDBInput{ - workflowID: internalHandle.GetWorkflowID(), - status: status, - err: err, - output: result, - }) - if recordErr != nil { - outcomeChan <- workflowOutcome[R]{result: *new(R), err: recordErr} - close(outcomeChan) // Close the channel to signal completion - return + typedOutcomeChan <- workflowOutcome[R]{ + result: typedResult, + err: resultErr, + } + }() + + typedHandle := &workflowHandle[R]{ + workflowID: handle.workflowID, + outcomeChan: typedOutcomeChan, + dbosContext: handle.dbosContext, } - outcomeChan <- workflowOutcome[R]{result: result, err: err} - close(outcomeChan) // Close the channel to signal completion - }() - return handle, nil + return typedHandle, nil + } + + // Should never happen + return nil, fmt.Errorf("unexpected workflow handle type: %T", handle) } func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { @@ -483,7 +463,6 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input // Check if we are within a workflow (and thus a child workflow) parentWorkflowState, ok := dbosCtx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil - // TODO Check if cancelled // Generate an ID for the workflow if not provided @@ -501,7 +480,7 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := dbosCtx.(*dbosContext).systemDB.CheckChildWorkflow(dbosCtx.(*dbosContext).ctx, parentWorkflowState.workflowID, parentWorkflowState.stepCounter) + childWorkflowID, err := dbosCtx.(*dbosContext).systemDB.CheckChildWorkflow(dbosCtx.(*dbosContext).ctx, parentWorkflowState.workflowID, parentWorkflowState.stepID) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } @@ -518,7 +497,7 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input } workflowStatus := WorkflowStatus{ - Name: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // TODO factor out somewhere else so we dont' have to reflect here + Name: params.workflowName, ApplicationVersion: params.applicationVersion, ExecutorID: dbosCtx.GetExecutorID(), Status: status, @@ -561,11 +540,11 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input // Record child workflow relationship if this is a child workflow if isChildWorkflow { // Get the step ID that was used for generating the child workflow ID - stepID := parentWorkflowState.stepCounter + stepID := parentWorkflowState.stepID childInput := recordChildWorkflowDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: workflowStatus.ID, - stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this + stepName: params.workflowName, stepID: stepID, tx: tx, } @@ -575,113 +554,179 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input } } + // Channel to receive the outcome from the goroutine + // The buffer size of 1 allows the goroutine to send the outcome without blocking + // In addition it allows the channel to be garbage collected + outcomeChan := make(chan workflowOutcome[any], 1) + + // Create workflow state to track step execution + wfState := &workflowState{ + workflowID: workflowID, + stepID: -1, + } + + // Run the function in a goroutine + workflowCtx := WithValue(dbosCtx, workflowStateKey, wfState) + dbosCtx.(*dbosContext).workflowsWg.Add(1) + go func() { + defer dbosCtx.(*dbosContext).workflowsWg.Done() + result, err := fn(workflowCtx, input) + status := WorkflowStatusSuccess + if err != nil { + status = WorkflowStatusError + } + recordErr := dbosCtx.(*dbosContext).systemDB.UpdateWorkflowOutcome(dbosCtx.(*dbosContext).ctx, updateWorkflowOutcomeDBInput{ + workflowID: workflowID, + status: status, + err: err, + output: result, + }) + if recordErr != nil { + outcomeChan <- workflowOutcome[any]{result: nil, err: recordErr} + close(outcomeChan) // Close the channel to signal completion + return + } + outcomeChan <- workflowOutcome[any]{result: result, err: err} + close(outcomeChan) // Close the channel to signal completion + }() + // Commit the transaction if err := tx.Commit(dbosCtx.(*dbosContext).ctx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowHandleInternal{workflowID: workflowID}, nil + return &workflowHandle[any]{workflowID: workflowID, outcomeChan: outcomeChan, dbosContext: dbosCtx}, nil } /******************************/ /******* STEP FUNCTIONS *******/ /******************************/ -type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error) -type StepFunc func(ctx context.Context, input any) (any, error) +type StepFunc func(ctx context.Context, input ...any) (any, error) +type GenericStepFunc[R any] func(ctx context.Context, input ...any) (R, error) + +const StepParamsKey DBOSContextKey = "stepParams" type StepParams struct { MaxRetries int BackoffFactor float64 BaseInterval time.Duration MaxInterval time.Duration -} - -// StepOption is a functional option for configuring step parameters -type StepOption func(*StepParams) - -// WithStepMaxRetries sets the maximum number of retries for a step -func WithStepMaxRetries(maxRetries int) StepOption { - return func(p *StepParams) { - p.MaxRetries = maxRetries + StepName string +} + +// setStepParamDefaults returns a StepParams struct with all defaults properly set +func setStepParamDefaults(params *StepParams, stepName string) *StepParams { + if params == nil { + return &StepParams{ + MaxRetries: 0, // Default to no retries + BackoffFactor: 2.0, + BaseInterval: 100 * time.Millisecond, // Default base interval + MaxInterval: 5 * time.Second, // Default max interval + StepName: typeErasedStepNameToStepName[stepName], + } } -} -// WithBackoffFactor sets the backoff factor for retries (multiplier for exponential backoff) -func WithBackoffFactor(backoffFactor float64) StepOption { - return func(p *StepParams) { - p.BackoffFactor = backoffFactor + // Set defaults for zero values + if params.BackoffFactor == 0 { + params.BackoffFactor = 2.0 // Default backoff factor } -} - -// WithBaseInterval sets the base delay for the first retry -func WithBaseInterval(baseInterval time.Duration) StepOption { - return func(p *StepParams) { - p.BaseInterval = baseInterval + if params.BaseInterval == 0 { + params.BaseInterval = 100 * time.Millisecond // Default base interval + } + if params.MaxInterval == 0 { + params.MaxInterval = 5 * time.Second // Default max interval } + if params.StepName == "" { + // If the step name is not provided, use the function name + params.StepName = typeErasedStepNameToStepName[stepName] + } + + return params } -// WithMaxInterval sets the maximum delay for retries -func WithMaxInterval(maxInterval time.Duration) StepOption { - return func(p *StepParams) { - p.MaxInterval = maxInterval +var typeErasedStepNameToStepName = make(map[string]string) + +func RunAsStep[R any](dbosCtx DBOSContext, fn GenericStepFunc[R], input ...any) (R, error) { + if dbosCtx == nil { + return *new(R), newStepExecutionError("", "", "dbosCtx cannot be nil") } -} -func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName string, opts ...StepOption) (any, error) { if fn == nil { - return nil, newStepExecutionError("", "", "step function cannot be nil") + return *new(R), newStepExecutionError("", "", "step function cannot be nil") } - // Apply options to build params with defaults - params := StepParams{ - MaxRetries: 0, - BackoffFactor: 2.0, - BaseInterval: 500 * time.Millisecond, - MaxInterval: 1 * time.Hour, + // Type-erase the function based on its actual type + typeErasedFn := StepFunc(func(ctx context.Context, i ...any) (any, error) { + return fn(ctx, i...) + }) + + typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + + // Call the executor method + result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input...) + if err != nil { + return *new(R), err } - for _, opt := range opts { - opt(¶ms) + + // Type-check and cast the result + typedResult, ok := result.(R) + if !ok { + return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result) } + return typedResult, nil +} + +func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, error) { // Get workflow state from context wfState, ok := c.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { - return nil, newStepExecutionError("", stepName, "workflow state not found in context: are you running this step within a workflow?") + // TODO: try to print step name + return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?") + } + + if fn == nil { + // TODO: try to print step name + return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil") + } + + // Look up for step parameters in the context and set defaults + params, ok := c.ctx.Value(StepParamsKey).(*StepParams) + if !ok { + params = nil } + params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) // If within a step, just run the function directly if wfState.isWithinStep { return fn(c.ctx, input) } - // Get next step ID - stepID := wfState.NextStepID() + // Setup step state + stepState := workflowState{ + workflowID: wfState.workflowID, + stepID: wfState.NextStepID(), // crucially, this increments the step ID on the *workflow* state + isWithinStep: true, + } // Check the step is cancelled, has already completed, or is called with a different name recordedOutput, err := c.systemDB.CheckOperationExecution(c.ctx, checkOperationExecutionDBInput{ - workflowID: wfState.workflowID, - stepID: stepID, - stepName: stepName, + workflowID: stepState.workflowID, + stepID: stepState.stepID, + stepName: params.StepName, }) if err != nil { - return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) + return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("checking operation execution: %v", err)) } if recordedOutput != nil { return recordedOutput.output, recordedOutput.err } - // Execute step with retry logic if MaxRetries > 0 - stepState := workflowState{ - workflowID: wfState.workflowID, - stepCounter: wfState.stepCounter, - isWithinStep: true, - } - // Spawn a child DBOSContext with the step state - stepCtx := c.withValue(workflowStateKey, &stepState) + stepCtx := WithValue(c, workflowStateKey, &stepState) - stepOutput, stepError := fn(stepCtx, input) + stepOutput, stepError := fn(stepCtx, input...) // Retry if MaxRetries > 0 and the first execution failed var joinedErrors error @@ -696,12 +741,12 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName delay = time.Duration(math.Min(exponentialDelay, float64(params.MaxInterval))) } - getLogger().Error("step failed, retrying", "step_name", stepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) + getLogger().Error("step failed, retrying", "step_name", params.StepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) // Wait before retry select { case <-c.ctx.Done(): - return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", c.ctx.Err())) + return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("context cancelled during retry: %v", c.ctx.Err())) case <-time.After(delay): // Continue to retry } @@ -719,7 +764,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName // If max retries reached, create MaxStepRetriesExceeded error if retry == params.MaxRetries { - stepError = newMaxStepRetriesExceededError(wfState.workflowID, stepName, params.MaxRetries, joinedErrors) + stepError = newMaxStepRetriesExceededError(stepState.workflowID, params.StepName, params.MaxRetries, joinedErrors) break } } @@ -727,54 +772,20 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any, stepName // Record the final result dbInput := recordOperationResultDBInput{ - workflowID: wfState.workflowID, - stepName: stepName, - stepID: stepID, + workflowID: stepState.workflowID, + stepName: params.StepName, + stepID: stepState.stepID, err: stepError, output: stepOutput, } recErr := c.systemDB.RecordOperationResult(c.ctx, dbInput) if recErr != nil { - return nil, newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) + return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("recording step outcome: %v", recErr)) } return stepOutput, stepError } -func RunAsStep[P any, R any](dbosCtx DBOSContext, fn GenericStepFunc[P, R], input P, opts ...StepOption) (R, error) { - if dbosCtx == nil { - return *new(R), errors.New("dbosCtx cannot be nil") - } - if fn == nil { - return *new(R), newStepExecutionError("", "", "step function cannot be nil") - } - - stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() - - // Type-erase the function - typeErasedFn := func(ctx context.Context, input any) (any, error) { - typedInput, ok := input.(P) - if !ok { - return nil, fmt.Errorf("unexpected input type: expected %T, got %T", *new(P), input) - } - return fn(ctx, typedInput) - } - - // Call the executor method - result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input, stepName, opts...) - if err != nil { - return *new(R), err - } - - // Type-check and cast the result - typedResult, ok := result.(R) - if !ok { - return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result) - } - - return typedResult, nil -} - /****************************************/ /******* WORKFLOW COMMUNICATIONS ********/ /****************************************/ From c4f0c60b1ca59b1276eb7048fcf4e8a2fe57eeeb Mon Sep 17 00:00:00 2001 From: maxdml Date: Thu, 31 Jul 2025 16:26:36 -0700 Subject: [PATCH 07/30] remove register functions from interface -- simply does nothing if we do not have the right concrete type --- dbos/dbos.go | 11 +++++--- dbos/workflow.go | 68 +++++++++++++++++++++++++++--------------------- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index e9ed8e00..e42f5630 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -67,10 +67,6 @@ type DBOSContext interface { Launch() error Shutdown() - // Workflow registration - RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) - RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) - // Workflow operations RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, error) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) @@ -93,6 +89,8 @@ type DBOSContext interface { type dbosContext struct { ctx context.Context + launched bool + systemDB SystemDatabase adminServer *adminServer config *Config @@ -224,6 +222,10 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { } func (c *dbosContext) Launch() error { + if c.launched { + return newInitializationError("DBOS is already launched") + } + // Start the system database c.systemDB.Launch(context.Background()) @@ -269,6 +271,7 @@ func (c *dbosContext) Launch() error { } logger.Info("DBOS initialized", "app_version", c.applicationVersion, "executor_id", c.executorID) + c.launched = true return nil } diff --git a/dbos/workflow.go b/dbos/workflow.go index f2596ebd..ff4bca72 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -208,27 +208,41 @@ type workflowRegistryEntry struct { } // Register adds a workflow function to the registry (thread-safe, only once per name) -func (c *dbosContext) RegisterWorkflow(fqn string, fn WrappedWorkflowFunc, maxRetries int) { +func registerWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkflowFunc, maxRetries int) { + // Skip if we don't have a concrete dbosContext + c, ok := dbosCtx.(*dbosContext) + if !ok { + return + } + c.workflowRegMutex.Lock() defer c.workflowRegMutex.Unlock() - if _, exists := c.workflowRegistry[fqn]; exists { - getLogger().Error("workflow function already registered", "fqn", fqn) - panic(newConflictingRegistrationError(fqn)) + if _, exists := c.workflowRegistry[workflowName]; exists { + getLogger().Error("workflow function already registered", "fqn", workflowName) + panic(newConflictingRegistrationError(workflowName)) } - c.workflowRegistry[fqn] = workflowRegistryEntry{ + fmt.Println("registering workflow", "fqn", workflowName, "max_retries", maxRetries) + + c.workflowRegistry[workflowName] = workflowRegistryEntry{ wrappedFunction: fn, maxRetries: maxRetries, } } -func (c *dbosContext) RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) { +func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) { + // Skip if we don't have a concrete dbosContext + c, ok := dbosCtx.(*dbosContext) + if !ok { + return + } + c.getWorkflowScheduler().Start() var entryID cron.EntryID entryID, err := c.getWorkflowScheduler().AddFunc(cronSchedule, func() { // Execute the workflow on the cron schedule once DBOS is launched - if c == nil { + if !c.launched { return } // Get the scheduled time from the cron entry @@ -238,19 +252,19 @@ func (c *dbosContext) RegisterScheduledWorkflow(fqn string, fn WrappedWorkflowFu // Use Next if Prev is not set, which will only happen for the first run scheduledTime = entry.Next } - wfID := fmt.Sprintf("sched-%s-%s", fqn, scheduledTime) // XXX we can rethink the format + wfID := fmt.Sprintf("sched-%s-%s", workflowName, scheduledTime) // XXX we can rethink the format fn(c, scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) }) if err != nil { panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) } - getLogger().Info("Registered scheduled workflow", "fqn", fqn, "cron_schedule", cronSchedule) + getLogger().Info("Registered scheduled workflow", "fqn", workflowName, "cron_schedule", cronSchedule) } type workflowRegistrationParams struct { cronSchedule string maxRetries int - // Likely we will allow a name here + workflowName string } type workflowRegistrationOption func(*workflowRegistrationParams) @@ -271,19 +285,27 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } +func WithWorkflowName(name string) workflowRegistrationOption { + return func(p *workflowRegistrationParams) { + p.workflowName = name + } +} + // RegisterWorkflow registers the provided function as a durable workflow with the provided DBOSContext workflow registry // If the workflow is a scheduled workflow (determined by the presence of a cron schedule), it will also register a cron job to execute it // RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding // The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions // To execute the workflow, use DBOSContext.RunAsWorkflow -func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { +func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, workflowName string, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { if dbosCtx == nil { panic("dbosCtx cannot be nil") } registrationParams := workflowRegistrationParams{ - maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, + maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, + workflowName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), } + for _, opt := range opts { opt(®istrationParams) } @@ -291,7 +313,6 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ if fn == nil { panic("workflow function cannot be nil") } - fqn := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // Registry the input/output types for gob encoding var p P @@ -300,10 +321,10 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ gob.Register(r) // Register a type-erased version of the durable workflow for recovery - typeErasedWrapper := func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { + typeErasedWrapper := WrappedWorkflowFunc(func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { - return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) + return nil, newWorkflowUnexpectedInputType(workflowName, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) @@ -312,15 +333,15 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ return nil, err } return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), dbosContext: ctx}, nil - } - dbosCtx.RegisterWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) + }) + registerWorkflow(dbosCtx, registrationParams.workflowName, typeErasedWrapper, registrationParams.maxRetries) // If this is a scheduled workflow, register a cron job if registrationParams.cronSchedule != "" { if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { panic(fmt.Sprintf("scheduled workflow function must accept a time.Time as input, got %T", p)) } - dbosCtx.RegisterScheduledWorkflow(fqn, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) + registerScheduledWorkflow(dbosCtx, registrationParams.workflowName, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) } } @@ -383,18 +404,7 @@ func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { } } -func WithWorkflowName(name string) WorkflowOption { - return func(p *workflowParams) { - if len(p.workflowName) == 0 { - p.workflowName = name - } - } -} - func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { - // Set the workflow name in the options -- will not be applied if the user provided a name - opts = append(opts, WithWorkflowName(runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name())) - typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { return fn(ctx, input.(P)) }) From 589b6d3564f0e8567874cc471cf9fdac7255672e Mon Sep 17 00:00:00 2001 From: maxdml Date: Thu, 31 Jul 2025 16:30:37 -0700 Subject: [PATCH 08/30] fix --- dbos/workflow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index ff4bca72..cbd5a7f5 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -296,7 +296,7 @@ func WithWorkflowName(name string) workflowRegistrationOption { // RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding // The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions // To execute the workflow, use DBOSContext.RunAsWorkflow -func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, workflowName string, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { +func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { if dbosCtx == nil { panic("dbosCtx cannot be nil") } From 38417e28734a6fe308220ede965d70475ece0b00 Mon Sep 17 00:00:00 2001 From: maxdml Date: Thu, 31 Jul 2025 19:15:16 -0700 Subject: [PATCH 09/30] fix and nits --- dbos/dbos.go | 2 +- dbos/workflow.go | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index e42f5630..d0f7847a 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -61,7 +61,7 @@ func processConfig(inputConfig *Config) (*Config, error) { } type DBOSContext interface { - context.Context // Standard Go context behavior + context.Context // Context Lifecycle Launch() error diff --git a/dbos/workflow.go b/dbos/workflow.go index cbd5a7f5..8989ced7 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -223,8 +223,6 @@ func registerWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkfl panic(newConflictingRegistrationError(workflowName)) } - fmt.Println("registering workflow", "fqn", workflowName, "max_retries", maxRetries) - c.workflowRegistry[workflowName] = workflowRegistryEntry{ wrappedFunction: fn, maxRetries: maxRetries, @@ -324,7 +322,7 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ typeErasedWrapper := WrappedWorkflowFunc(func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { - return nil, newWorkflowUnexpectedInputType(workflowName, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) + return nil, newWorkflowUnexpectedInputType(registrationParams.workflowName, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) @@ -702,7 +700,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, } // Look up for step parameters in the context and set defaults - params, ok := c.ctx.Value(StepParamsKey).(*StepParams) + params, ok := c.Value(StepParamsKey).(*StepParams) if !ok { params = nil } From b6eb296d368ca02ddc967dc2362df071dee3eb60 Mon Sep 17 00:00:00 2001 From: maxdml Date: Thu, 31 Jul 2025 19:15:45 -0700 Subject: [PATCH 10/30] WIP: first pass at updating the tests --- dbos/logger_test.go | 16 +- dbos/queues_test.go | 229 ++++++++------- dbos/serialization_test.go | 96 ++++--- dbos/utils_test.go | 19 +- dbos/workflows_test.go | 566 +++++++++++++++++++------------------ 5 files changed, 489 insertions(+), 437 deletions(-) diff --git a/dbos/logger_test.go b/dbos/logger_test.go index d33b8efb..4b7073b8 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -11,20 +11,20 @@ func TestLogger(t *testing.T) { databaseURL := getDatabaseURL(t) t.Run("Default logger", func(t *testing.T) { - executor, err := NewDBOSContext(Config{ + dbosCtx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) // Create executor with default logger if err != nil { t.Fatalf("Failed to create executor with default logger: %v", err) } - err = executor.Launch() + err = dbosCtx.Launch() if err != nil { t.Fatalf("Failed to launch with default logger: %v", err) } t.Cleanup(func() { - if executor != nil { - executor.Shutdown() + if dbosCtx != nil { + dbosCtx.Shutdown() } }) @@ -47,7 +47,7 @@ func TestLogger(t *testing.T) { // Add some context to the slog logger slogLogger = slogLogger.With("service", "dbos-test", "environment", "test") - executor, err := NewDBOSContext(Config{ + dbosCtx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", Logger: slogLogger, @@ -55,13 +55,13 @@ func TestLogger(t *testing.T) { if err != nil { t.Fatalf("Failed to create executor with custom logger: %v", err) } - err = executor.Launch() + err = dbosCtx.Launch() if err != nil { t.Fatalf("Failed to launch with custom logger: %v", err) } t.Cleanup(func() { - if executor != nil { - executor.Shutdown() + if dbosCtx != nil { + dbosCtx.Shutdown() } }) diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 9292904b..1ea9592d 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -44,58 +44,73 @@ func queueWorkflow(ctx DBOSContext, input string) (string, error) { return step1, nil } -func queueStep(ctx context.Context, input string) (string, error) { - return input, nil +func queueStep(ctx context.Context, input ...any) (string, error) { + if len(input) == 0 { + return "", nil + } + str, ok := input[0].(string) + if !ok { + return "", fmt.Errorf("expected string input, got %T", input[0]) + } + return str, nil } func TestWorkflowQueues(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) - // Setup workflows with executor - queueWf := RegisterWorkflow(executor, queueWorkflow) + // Register workflows with dbosContext + RegisterWorkflow(dbosCtx, queueWorkflow) // Create workflow with child that can call the main workflow - queueWfWithChild := RegisterWorkflow[string, string](executor, func(ctx DBOSContext, input string) (string, error) { + queueWorkflowWithChild := func(ctx DBOSContext, input string) (string, error) { // Start a child workflow - childHandle, err := queueWf(ctx, input+"-child") + childHandle, err := RunAsWorkflow(ctx, queueWorkflow, input+"-child") if err != nil { return "", fmt.Errorf("failed to start child workflow: %v", err) } // Get result from child workflow - childResult, err := childHandle.GetResult(ctx) + childResult, err := childHandle.GetResult() if err != nil { return "", fmt.Errorf("failed to get child result: %v", err) } return childResult, nil - }) + } + RegisterWorkflow(dbosCtx, queueWorkflowWithChild) // Create workflow that enqueues another workflow - queueWfThatEnqueues := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { + queueWorkflowThatEnqueues := func(ctx DBOSContext, input string) (string, error) { // Enqueue another workflow to the same queue - enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.name)) + enqueuedHandle, err := RunAsWorkflow(ctx, queueWorkflow, input+"-enqueued", WithQueue(queue.name)) if err != nil { return "", fmt.Errorf("failed to enqueue workflow: %v", err) } // Get result from the enqueued workflow - enqueuedResult, err := enqueuedHandle.GetResult(ctx) + enqueuedResult, err := enqueuedHandle.GetResult() if err != nil { return "", fmt.Errorf("failed to get enqueued workflow result: %v", err) } return enqueuedResult, nil - }) + } + RegisterWorkflow(dbosCtx, queueWorkflowThatEnqueues) - enqueueWorkflowDLQ := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { + enqueueWorkflowDLQ := func(ctx DBOSContext, input string) (string, error) { dlqStartEvent.Set() dlqCompleteEvent.Wait() return input, nil - }, WithMaxRetries(dlqMaxRetries)) + } + RegisterWorkflow(dbosCtx, enqueueWorkflowDLQ, WithMaxRetries(dlqMaxRetries)) + + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } t.Run("EnqueueWorkflow", func(t *testing.T) { - handle, err := queueWf(context.Background(), "test-input", WithQueue(queue.name)) + handle, err := RunAsWorkflow(dbosCtx, queueWorkflow, "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow: %v", err) } @@ -105,7 +120,7 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected handle to be of type workflowPollingHandle, got %T", handle) } - res, err := handle.GetResult(context.Background()) + res, err := handle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -113,18 +128,18 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected workflow result to be 'test-input', got %v", res) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } }) t.Run("EnqueuedWorkflowStartsChildWorkflow", func(t *testing.T) { - handle, err := queueWfWithChild(context.Background(), "test-input", WithQueue(queue.name)) + handle, err := RunAsWorkflow(dbosCtx, queueWorkflowWithChild, "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow with child: %v", err) } - res, err := handle.GetResult(context.Background()) + res, err := handle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -135,18 +150,18 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected workflow result to be '%s', got %v", expectedResult, res) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } }) t.Run("WorkflowEnqueuesAnotherWorkflow", func(t *testing.T) { - handle, err := queueWfThatEnqueues(context.Background(), "test-input", WithQueue(queue.name)) + handle, err := RunAsWorkflow(dbosCtx, queueWorkflowThatEnqueues, "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow that enqueues another workflow: %v", err) } - res, err := handle.GetResult(context.Background()) + res, err := handle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -157,7 +172,7 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected workflow result to be '%s', got %v", expectedResult, res) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } }) @@ -173,7 +188,7 @@ func TestWorkflowQueues(t *testing.T) { workflowID := "blocking-workflow-test" // Enqueue the workflow for the first time - originalHandle, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) + originalHandle, err := RunAsWorkflow(dbosCtx, enqueueWorkflowDLQ, "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to enqueue blocking workflow: %v", err) } @@ -184,7 +199,7 @@ func TestWorkflowQueues(t *testing.T) { // Try to enqueue the same workflow more times for i := range dlqMaxRetries * 2 { - _, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) + _, err := RunAsWorkflow(dbosCtx, enqueueWorkflowDLQ, "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to enqueue workflow attempt %d: %v", i+1, err) } @@ -204,7 +219,7 @@ func TestWorkflowQueues(t *testing.T) { // Check that the workflow hits DLQ after re-running max retries handles := make([]WorkflowHandle[any], 0, dlqMaxRetries+1) for i := range dlqMaxRetries { - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveryHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -230,7 +245,7 @@ func TestWorkflowQueues(t *testing.T) { // Check the workflow completes dlqCompleteEvent.Set() for _, handle := range handles { - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from recovered workflow handle: %v", err) } @@ -239,7 +254,7 @@ func TestWorkflowQueues(t *testing.T) { } } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after successive enqueues test") } }) @@ -254,23 +269,22 @@ var ( ) func TestQueueRecovery(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) - // Create workflows with executor - var recoveryStepWorkflow func(context.Context, int, ...WorkflowOption) (WorkflowHandle[int], error) - var recoveryWorkflow func(context.Context, string, ...WorkflowOption) (WorkflowHandle[[]int], error) + // Create workflows with dbosContext - recoveryStepWorkflow = RegisterWorkflow(executor, func(ctx context.Context, i int) (int, error) { + recoveryStepWorkflowFunc := func(ctx DBOSContext, i int) (int, error) { recoveryStepCounter++ recoveryStepEvents[i].Set() recoveryEvent.Wait() return i, nil - }) + } + RegisterWorkflow(dbosCtx, recoveryStepWorkflowFunc) - recoveryWorkflow = RegisterWorkflow(executor, func(ctx context.Context, input string) ([]int, error) { + recoveryWorkflowFunc := func(ctx DBOSContext, input string) ([]int, error) { handles := make([]WorkflowHandle[int], 0, 5) // 5 queued steps for i := range 5 { - handle, err := recoveryStepWorkflow(ctx, i, WithQueue(recoveryQueue.name)) + handle, err := RunAsWorkflow(ctx, recoveryStepWorkflowFunc, i, WithQueue(recoveryQueue.name)) if err != nil { return nil, fmt.Errorf("failed to enqueue step %d: %v", i, err) } @@ -279,14 +293,15 @@ func TestQueueRecovery(t *testing.T) { results := make([]int, 0, 5) for _, handle := range handles { - result, err := handle.GetResult(ctx) + result, err := handle.GetResult() if err != nil { return nil, fmt.Errorf("failed to get result for handle: %v", err) } results = append(results, result) } return results, nil - }) + } + RegisterWorkflow(dbosCtx, recoveryWorkflowFunc) queuedSteps := 5 @@ -297,7 +312,7 @@ func TestQueueRecovery(t *testing.T) { wfid := uuid.NewString() // Start the workflow. Wait for all steps to start. Verify that they started. - handle, err := recoveryWorkflow(context.Background(), "", WithWorkflowID(wfid)) + handle, err := RunAsWorkflow(dbosCtx, recoveryWorkflowFunc, "", WithWorkflowID(wfid)) if err != nil { t.Fatalf("failed to start workflow: %v", err) } @@ -312,7 +327,7 @@ func TestQueueRecovery(t *testing.T) { } // Recover the workflow, then resume it. - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveryHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -329,7 +344,7 @@ func TestQueueRecovery(t *testing.T) { for _, h := range recoveryHandles { if h.GetWorkflowID() == wfid { // Root workflow case - result, err := h.GetResult(context.Background()) + result, err := h.GetResult() if err != nil { t.Fatalf("failed to get result from recovered root workflow handle: %v", err) } @@ -344,7 +359,7 @@ func TestQueueRecovery(t *testing.T) { } } - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from original handle: %v", err) } @@ -358,11 +373,11 @@ func TestQueueRecovery(t *testing.T) { } // Rerun the workflow. Because each step is complete, none should start again. - rerunHandle, err := recoveryWorkflow(context.Background(), "test-input", WithWorkflowID(wfid)) + rerunHandle, err := RunAsWorkflow(dbosCtx, recoveryWorkflowFunc, "test-input", WithWorkflowID(wfid)) if err != nil { t.Fatalf("failed to rerun workflow: %v", err) } - rerunResult, err := rerunHandle.GetResult(context.Background()) + rerunResult, err := rerunHandle.GetResult() if err != nil { t.Fatalf("failed to get result from rerun handle: %v", err) } @@ -374,7 +389,7 @@ func TestQueueRecovery(t *testing.T) { t.Fatalf("expected recoveryStepCounter to remain %d, got %d", queuedSteps*2, recoveryStepCounter) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } } @@ -387,10 +402,10 @@ var ( ) func TestGlobalConcurrency(t *testing.T) { - executor := setupDBOS(t) + dbosContext := setupDBOS(t) - // Create workflow with executor - globalConcurrencyWorkflow := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { + // Create workflow with dbosContext + globalConcurrencyWorkflowFunc := func(ctx DBOSContext, input string) (string, error) { switch input { case "workflow1": workflowEvent1.Set() @@ -399,15 +414,16 @@ func TestGlobalConcurrency(t *testing.T) { workflowEvent2.Set() } return input, nil - }) + } + RegisterWorkflow(dbosContext, globalConcurrencyWorkflowFunc) // Enqueue two workflows - handle1, err := globalConcurrencyWorkflow(context.Background(), "workflow1", WithQueue(globalConcurrencyQueue.name)) + handle1, err := RunAsWorkflow(dbosContext, globalConcurrencyWorkflowFunc, "workflow1", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow1: %v", err) } - handle2, err := globalConcurrencyWorkflow(context.Background(), "workflow2", WithQueue(globalConcurrencyQueue.name)) + handle2, err := RunAsWorkflow(dbosContext, globalConcurrencyWorkflowFunc, "workflow2", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow2: %v", err) } @@ -431,7 +447,7 @@ func TestGlobalConcurrency(t *testing.T) { // Allow the first workflow to complete workflowDoneEvent.Set() - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from workflow1: %v", err) } @@ -442,14 +458,14 @@ func TestGlobalConcurrency(t *testing.T) { // Wait for the second workflow to start workflowEvent2.Wait() - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from workflow2: %v", err) } if result2 != "workflow2" { t.Fatalf("expected result from workflow2 to be 'workflow2', got %v", result2) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosContext) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } } @@ -471,30 +487,31 @@ var ( ) func TestWorkerConcurrency(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) - // Create workflow with executor - blockingWf := RegisterWorkflow(executor, func(ctx context.Context, i int) (int, error) { + // Create workflow with dbosContext + blockingWfFunc := func(ctx DBOSContext, i int) (int, error) { // Simulate a blocking operation startEvents[i].Set() completeEvents[i].Wait() return i, nil - }) + } + RegisterWorkflow(dbosCtx, blockingWfFunc) // First enqueue four blocking workflows - handle1, err := blockingWf(context.Background(), 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) + handle1, err := RunAsWorkflow(dbosCtx, blockingWfFunc, 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 1: %v", err) } - handle2, err := blockingWf(context.Background(), 1, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-2")) + handle2, err := RunAsWorkflow(dbosCtx, blockingWfFunc, 1, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-2")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 2: %v", err) } - _, err = blockingWf(context.Background(), 2, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-3")) + _, err = RunAsWorkflow(dbosCtx, blockingWfFunc, 2, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-3")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 3: %v", err) } - _, err = blockingWf(context.Background(), 3, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-4")) + _, err = RunAsWorkflow(dbosCtx, blockingWfFunc, 3, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-4")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 4: %v", err) } @@ -505,7 +522,7 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[1].IsSet || startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 1 to start, but others have started") } - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err := dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -516,12 +533,12 @@ func TestWorkerConcurrency(t *testing.T) { t.Fatalf("expected 3 workflows to be enqueued, got %d", len(workflows)) } - // Stop the queue runner before changing executor ID to avoid race conditions - stopQueueRunner() - // Change the executor ID to a different value - dbos.executorID = "worker-2" + // Stop the queue runner before changing dbosContext ID to avoid race conditions + stopQueueRunner(dbosCtx) + // Change the dbosContext ID to a different value + dbosCtx.(*dbosContext).executorID = "worker-2" // Restart the queue runner - restartQueueRunner() + restartQueueRunner(dbosCtx) // Wait for the second workflow to start on the second worker startEvents[1].Wait() @@ -529,7 +546,7 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 2 to start, but others have started") } - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err = dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -542,26 +559,26 @@ func TestWorkerConcurrency(t *testing.T) { // Unlock workflow 1, check wf 3 starts, check 4 stays blocked completeEvents[0].Set() - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from blocking workflow 1: %v", err) } if result1 != 0 { t.Fatalf("expected result from blocking workflow 1 to be 0, got %v", result1) } - // Stop the queue runner before changing executor ID to avoid race conditions - stopQueueRunner() - // Change the executor again and wait for the third workflow to start - dbos.executorID = "local" + // Stop the queue runner before changing dbosContext ID to avoid race conditions + stopQueueRunner(dbosCtx) + // Change the dbosContext again and wait for the third workflow to start + dbosCtx.(*dbosContext).executorID = "local" // Restart the queue runner - restartQueueRunner() + restartQueueRunner(dbosCtx) startEvents[2].Wait() // Ensure the fourth workflow is not started yet if startEvents[3].IsSet { t.Fatal("expected only blocking workflow 3 to start, but workflow 4 has started") } // Check that only one workflow is pending - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err = dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -574,22 +591,22 @@ func TestWorkerConcurrency(t *testing.T) { // Unlock workflow 2 and check wf 4 starts completeEvents[1].Set() - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from blocking workflow 2: %v", err) } if result2 != 1 { t.Fatalf("expected result from blocking workflow 2 to be 1, got %v", result2) } - // Stop the queue runner before changing executor ID to avoid race conditions - stopQueueRunner() - // change executor again and wait for the fourth workflow to start - dbos.executorID = "worker-2" + // Stop the queue runner before changing dbosContext ID to avoid race conditions + stopQueueRunner(dbosCtx) + // change dbosContext again and wait for the fourth workflow to start + dbosCtx.(*dbosContext).executorID = "worker-2" // Restart the queue runner - restartQueueRunner() + restartQueueRunner(dbosCtx) startEvents[3].Wait() // Check no workflow is enqueued - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err = dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -604,11 +621,11 @@ func TestWorkerConcurrency(t *testing.T) { completeEvents[2].Set() completeEvents[3].Set() - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } - dbos.executorID = "local" // Reset executor ID for future tests + dbosCtx.(*dbosContext).executorID = "local" // Reset executor ID for future tests } var ( @@ -620,26 +637,28 @@ var ( ) func TestWorkerConcurrencyXRecovery(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) - // Create workflows with executor - workerConcurrencyRecoveryBlockingWf1 := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { + // Create workflows with dbosContext + workerConcurrencyRecoveryBlockingWf1 := func(ctx DBOSContext, input string) (string, error) { workerConcurrencyRecoveryStartEvent1.Set() workerConcurrencyRecoveryCompleteEvent1.Wait() return input, nil - }) - workerConcurrencyRecoveryBlockingWf2 := RegisterWorkflow(executor, func(ctx context.Context, input string) (string, error) { + } + RegisterWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf1) + workerConcurrencyRecoveryBlockingWf2 := func(ctx DBOSContext, input string) (string, error) { workerConcurrencyRecoveryStartEvent2.Set() workerConcurrencyRecoveryCompleteEvent2.Wait() return input, nil - }) + } + RegisterWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf2) // Enqueue two workflows on a queue with worker concurrency = 1 - handle1, err := workerConcurrencyRecoveryBlockingWf1(context.Background(), "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) + handle1, err := RunAsWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf1, "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 1: %v", err) } - handle2, err := workerConcurrencyRecoveryBlockingWf2(context.Background(), "workflow2", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-2")) + handle2, err := RunAsWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf2, "workflow2", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-2")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 2: %v", err) } @@ -665,7 +684,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } // Now, manually call the recoverPendingWorkflows method - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveryHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -704,7 +723,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { workerConcurrencyRecoveryCompleteEvent2.Set() // Get result from first workflow - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from workflow1: %v", err) } @@ -713,7 +732,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } // Get result from second workflow - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from workflow2: %v", err) } @@ -722,7 +741,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } // Ensure queueEntriesAreCleanedUp is set to true - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after worker concurrency recovery test") } } @@ -731,15 +750,15 @@ var ( rateLimiterQueue = NewWorkflowQueue("test-rate-limiter-queue", WithRateLimiter(&RateLimiter{Limit: 5, Period: 1.8})) ) -func rateLimiterTestWorkflow(ctx context.Context, _ string) (time.Time, error) { +func rateLimiterTestWorkflow(ctx DBOSContext, _ string) (time.Time, error) { return time.Now(), nil // Return current time } func TestQueueRateLimiter(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) - // Create workflow with executor - rateLimiterWorkflow := RegisterWorkflow(executor, rateLimiterTestWorkflow) + // Create workflow with dbosContext + RegisterWorkflow(dbosCtx, rateLimiterTestWorkflow) limit := 5 period := 1.8 @@ -753,7 +772,7 @@ func TestQueueRateLimiter(t *testing.T) { // executed simultaneously, followed by a wait of the period, // followed by the next wave. for i := 0; i < limit*numWaves; i++ { - handle, err := rateLimiterWorkflow(context.Background(), "", WithQueue(rateLimiterQueue.name)) + handle, err := RunAsWorkflow(dbosCtx, rateLimiterTestWorkflow, "", WithQueue(rateLimiterQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow %d: %v", i, err) } @@ -762,7 +781,7 @@ func TestQueueRateLimiter(t *testing.T) { // Get results from all workflows for _, handle := range handles { - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from workflow: %v", err) } @@ -825,7 +844,7 @@ func TestQueueRateLimiter(t *testing.T) { } // Verify all queue entries eventually get cleaned up. - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after rate limiter test") } } diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index a171053c..7ab3a3be 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -18,12 +18,19 @@ import ( */ // Builtin types -func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { - return input, errors.New("step error") +func encodingStepBuiltinTypes(_ context.Context, input ...any) (int, error) { + if len(input) == 0 { + return 0, errors.New("step error") + } + val, ok := input[0].(int) + if !ok { + return 0, errors.New("step error") + } + return val, errors.New("step error") } -func encodingWorkflowBuiltinTypes(ctx context.Context, input string) (string, error) { - stepResult, err := RunAsStep(ctx, dbos, encodingStepBuiltinTypes, 123) +func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) { + stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123) return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err) } @@ -48,16 +55,23 @@ type SimpleStruct struct { B int } -func encodingWorkflowStruct(ctx context.Context, input WorkflowInputStruct) (StepOutputStruct, error) { - return RunAsStep(ctx, dbos, encodingStepStruct, StepInputStruct{ +func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) { + return RunAsStep(ctx, encodingStepStruct, StepInputStruct{ A: input.A, B: fmt.Sprintf("%d", input.B), }) } -func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputStruct, error) { +func encodingStepStruct(ctx context.Context, input ...any) (StepOutputStruct, error) { + if len(input) == 0 { + return StepOutputStruct{}, nil + } + stepInput, ok := input[0].(StepInputStruct) + if !ok { + return StepOutputStruct{}, nil + } return StepOutputStruct{ - A: input, + A: stepInput, B: "processed by encodingStepStruct", }, nil } @@ -65,19 +79,19 @@ func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputS func TestWorkflowEncoding(t *testing.T) { executor := setupDBOS(t) - // Create workflows with executor - builtinWf := RegisterWorkflow(executor, encodingWorkflowBuiltinTypes) - structWf := RegisterWorkflow(executor, encodingWorkflowStruct) + // Register workflows with executor + RegisterWorkflow(executor, encodingWorkflowBuiltinTypes) + RegisterWorkflow(executor, encodingWorkflowStruct) t.Run("BuiltinTypes", func(t *testing.T) { // Test a workflow that uses a built-in type (string) - directHandle, err := builtinWf(context.Background(), "test") + directHandle, err := RunAsWorkflow(executor, encodingWorkflowBuiltinTypes, "test") if err != nil { t.Fatalf("failed to execute workflow: %v", err) } // Test result and error from direct handle - directHandleResult, err := directHandle.GetResult(context.Background()) + directHandleResult, err := directHandle.GetResult() if directHandleResult != "123" { t.Fatalf("expected direct handle result to be '123', got %v", directHandleResult) } @@ -86,11 +100,11 @@ func TestWorkflowEncoding(t *testing.T) { } // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[string](dbos, directHandle.GetWorkflowID()) + retrieveHandler, err := RetrieveWorkflow[string](executor.(*dbosContext), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to retrieve workflow: %v", err) } - retrievedResult, err := retrieveHandler.GetResult(context.Background()) + retrievedResult, err := retrieveHandler.GetResult() if retrievedResult != "123" { t.Fatalf("expected retrieved result to be '123', got %v", retrievedResult) } @@ -99,7 +113,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := executor.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -137,7 +151,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from GetWorkflowSteps - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) + steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -170,13 +184,13 @@ func TestWorkflowEncoding(t *testing.T) { B: 456, } - directHandle, err := structWf(context.Background(), input) + directHandle, err := RunAsWorkflow(executor, encodingWorkflowStruct, input) if err != nil { t.Fatalf("failed to execute step workflow: %v", err) } // Test result from direct handle - directResult, err := directHandle.GetResult(context.Background()) + directResult, err := directHandle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -194,11 +208,11 @@ func TestWorkflowEncoding(t *testing.T) { } // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](dbos, directHandle.GetWorkflowID()) + retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](executor.(*dbosContext), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to retrieve step workflow: %v", err) } - retrievedResult, err := retrieveHandler.GetResult(context.Background()) + retrievedResult, err := retrieveHandler.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -216,7 +230,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := executor.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -258,7 +272,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from GetWorkflowSteps - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) + steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -300,7 +314,7 @@ type UserDefinedEventData struct { } `json:"details"` } -func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, error) { +func setEventUserDefinedTypeWorkflow(ctx DBOSContext, input string) (string, error) { eventData := UserDefinedEventData{ ID: 42, Name: "test-event", @@ -313,7 +327,7 @@ func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, }, } - err := SetEvent(ctx, dbos, WorkflowSetEventInputGeneric[UserDefinedEventData]{Key: input, Message: eventData}) + err := SetEvent(ctx, WorkflowSetEventInputGeneric[UserDefinedEventData]{Key: input, Message: eventData}) if err != nil { return "", err } @@ -323,18 +337,18 @@ func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, func TestSetEventSerialize(t *testing.T) { executor := setupDBOS(t) - // Create workflow with executor - setEventUserDefinedTypeWf := RegisterWorkflow(executor, setEventUserDefinedTypeWorkflow) + // Register workflow with executor + RegisterWorkflow(executor, setEventUserDefinedTypeWorkflow) t.Run("SetEventUserDefinedType", func(t *testing.T) { // Start a workflow that sets an event with a user-defined type - setHandle, err := setEventUserDefinedTypeWf(context.Background(), "user-defined-key") + setHandle, err := RunAsWorkflow(executor, setEventUserDefinedTypeWorkflow, "user-defined-key") if err != nil { t.Fatalf("failed to start workflow with user-defined event type: %v", err) } // Wait for the workflow to complete - result, err := setHandle.GetResult(context.Background()) + result, err := setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from user-defined event workflow: %v", err) } @@ -343,7 +357,7 @@ func TestSetEventSerialize(t *testing.T) { } // Retrieve the event to verify it was properly serialized and can be deserialized - retrievedEvent, err := GetEvent[UserDefinedEventData](context.Background(), dbos, WorkflowGetEventInput{ + retrievedEvent, err := GetEvent[UserDefinedEventData](executor, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "user-defined-key", Timeout: 3 * time.Second, @@ -374,7 +388,7 @@ func TestSetEventSerialize(t *testing.T) { }) } -func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (string, error) { +func sendUserDefinedTypeWorkflow(ctx DBOSContext, destinationID string) (string, error) { // Create an instance of our user-defined type inside the workflow sendData := UserDefinedEventData{ ID: 42, @@ -390,7 +404,7 @@ func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (str // Send should automatically register this type with gob // Note the explicit type parameter since compiler cannot infer UserDefinedEventData from string input - err := Send(ctx, dbos, WorkflowSendInput[UserDefinedEventData]{ + err := Send(ctx, WorkflowSendInput[UserDefinedEventData]{ DestinationID: destinationID, Topic: "user-defined-topic", Message: sendData, @@ -401,9 +415,9 @@ func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (str return "user-defined-message-sent", nil } -func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefinedEventData, error) { +func recvUserDefinedTypeWorkflow(ctx DBOSContext, input string) (UserDefinedEventData, error) { // Receive the user-defined type message - result, err := Recv[UserDefinedEventData](ctx, dbos, WorkflowRecvInput{ + result, err := Recv[UserDefinedEventData](ctx, WorkflowRecvInput{ Topic: "user-defined-topic", Timeout: 3 * time.Second, }) @@ -413,25 +427,25 @@ func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefined func TestSendSerialize(t *testing.T) { executor := setupDBOS(t) - // Create workflows with executor - sendUserDefinedTypeWf := RegisterWorkflow(executor, sendUserDefinedTypeWorkflow) - recvUserDefinedTypeWf := RegisterWorkflow(executor, recvUserDefinedTypeWorkflow) + // Register workflows with executor + RegisterWorkflow(executor, sendUserDefinedTypeWorkflow) + RegisterWorkflow(executor, recvUserDefinedTypeWorkflow) t.Run("SendUserDefinedType", func(t *testing.T) { // Start a receiver workflow first - recvHandle, err := recvUserDefinedTypeWf(context.Background(), "recv-input") + recvHandle, err := RunAsWorkflow(executor, recvUserDefinedTypeWorkflow, "recv-input") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Start a sender workflow that sends a message with a user-defined type - sendHandle, err := sendUserDefinedTypeWf(context.Background(), recvHandle.GetWorkflowID()) + sendHandle, err := RunAsWorkflow(executor, sendUserDefinedTypeWorkflow, recvHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to start workflow with user-defined send type: %v", err) } // Wait for the sender workflow to complete - sendResult, err := sendHandle.GetResult(context.Background()) + sendResult, err := sendHandle.GetResult() if err != nil { t.Fatalf("failed to get result from user-defined send workflow: %v", err) } @@ -440,7 +454,7 @@ func TestSendSerialize(t *testing.T) { } // Wait for the receiver workflow to complete and get the message - receivedData, err := recvHandle.GetResult(context.Background()) + receivedData, err := recvHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } diff --git a/dbos/utils_test.go b/dbos/utils_test.go index 610e45e1..48bad8e3 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -54,32 +54,26 @@ func setupDBOS(t *testing.T) DBOSContext { t.Fatalf("failed to drop test database: %v", err) } - executor, err := NewDBOSContext(Config{ + dbosContext, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) if err != nil { t.Fatalf("failed to create DBOS instance: %v", err) } - - err = executor.Launch() - if err != nil { - t.Fatalf("failed to launch DBOS instance: %v", err) - } - - if executor == nil { + if dbosContext == nil { t.Fatal("expected DBOS instance but got nil") } // Register cleanup to run after test completes t.Cleanup(func() { fmt.Println("Cleaning up DBOS instance...") - if executor != nil { - executor.Shutdown() + if dbosContext != nil { + dbosContext.Shutdown() } }) - return executor + return dbosContext } /* Event struct provides a simple synchronization primitive that can be used to signal between goroutines. */ @@ -135,6 +129,7 @@ func restartQueueRunner(executor DBOSContext) { if executor != nil { exec := executor.(*dbosContext) // Create new context and cancel function + // FIXME: cancellation now has to go through the DBOSContext ctx, cancel := context.WithCancel(context.Background()) exec.queueRunnerCtx = ctx exec.queueRunnerCancelFunc = cancel @@ -143,7 +138,7 @@ func restartQueueRunner(executor DBOSContext) { // Start the queue runner in a goroutine go func() { defer close(exec.queueRunnerDone) - queueRunner(ctx, exec) + queueRunner(exec) }() } } diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index b17c70f4..44d2e288 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -6,8 +6,7 @@ Test workflow and steps features [x] workflow idempotency [x] workflow DLQ [] workflow conflicting name -[] workflow timeout -[] workflow deadlines +[] workflow timeouts & deadlines (including child workflows) */ import ( @@ -24,33 +23,42 @@ import ( // Global counter for idempotency testing var idempotencyCounter int64 -func simpleWorkflow(ctxt context.Context, input string) (string, error) { +func simpleWorkflow(dbosCtx DBOSContext, input string) (string, error) { return input, nil } -func simpleWorkflowError(ctx context.Context, input string) (int, error) { +func simpleWorkflowError(dbosCtx DBOSContext, input string) (int, error) { return 0, fmt.Errorf("failure") } -func simpleWorkflowWithStep(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, dbos, simpleStep, input) +func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, simpleStep, input) } -func simpleStep(ctx context.Context, input string) (string, error) { +func simpleStep(ctx context.Context, input ...any) (string, error) { return "from step", nil } -func simpleStepError(ctx context.Context, input string) (string, error) { +func simpleStepError(ctx context.Context, input ...any) (string, error) { return "", fmt.Errorf("step failure") } -func simpleWorkflowWithStepError(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, dbos, simpleStepError, input) +func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, simpleStepError, input) } // idempotencyWorkflow increments a global counter and returns the input -func incrementCounter(_ context.Context, value int64) (int64, error) { - idempotencyCounter += value +func incrementCounter(_ context.Context, value ...any) (int64, error) { + if len(value) == 0 { + return 0, fmt.Errorf("expected int64 value") + } + val, ok := value[0].(int64) + if !ok { + return 0, fmt.Errorf("expected int64, got %T", value[0]) + } + fmt.Println("incrementCounter called with value:", val) + idempotencyCounter += val + fmt.Println("Current idempotency counter:", idempotencyCounter) return idempotencyCounter, nil } @@ -58,65 +66,68 @@ func incrementCounter(_ context.Context, value int64) (int64, error) { type workflowStruct struct{} // Pointer receiver method -func (w *workflowStruct) simpleWorkflow(ctx context.Context, input string) (string, error) { - return simpleWorkflow(ctx, input) +func (w *workflowStruct) simpleWorkflow(dbosCtx DBOSContext, input string) (string, error) { + return simpleWorkflow(dbosCtx, input) } // Value receiver method on the same struct -func (w workflowStruct) simpleWorkflowValue(ctx context.Context, input string) (string, error) { +func (w workflowStruct) simpleWorkflowValue(dbosCtx DBOSContext, input string) (string, error) { return input + "-value", nil } // interface for workflow methods type TestWorkflowInterface interface { - Execute(ctx context.Context, input string) (string, error) + Execute(dbosCtx DBOSContext, input string) (string, error) } type workflowImplementation struct { field string } -func (w *workflowImplementation) Execute(ctx context.Context, input string) (string, error) { +func (w *workflowImplementation) Execute(dbosCtx DBOSContext, input string) (string, error) { return input + "-" + w.field + "-interface", nil } // Generic workflow function -func Identity[T any](ctx context.Context, in T) (T, error) { +func Identity[T any](dbosCtx DBOSContext, in T) (T, error) { return in, nil } -func TestWorkflowsWrapping(t *testing.T) { +func TestWorkflowsRegistration(t *testing.T) { executor := setupDBOS(t) + dbosCtx := executor // Setup workflows with executor - simpleWf := RegisterWorkflow(executor, simpleWorkflow) - simpleWfError := RegisterWorkflow(executor, simpleWorkflowError) - simpleWfWithStep := RegisterWorkflow(executor, simpleWorkflowWithStep) - simpleWfWithStepError := RegisterWorkflow(executor, simpleWorkflowWithStepError) + RegisterWorkflow(dbosCtx, simpleWorkflow) + RegisterWorkflow(dbosCtx, simpleWorkflowError) + RegisterWorkflow(dbosCtx, simpleWorkflowWithStep) + RegisterWorkflow(dbosCtx, simpleWorkflowWithStepError) // struct methods s := workflowStruct{} - simpleWfStruct := RegisterWorkflow(executor, s.simpleWorkflow) - simpleWfValue := RegisterWorkflow(executor, s.simpleWorkflowValue) + RegisterWorkflow(dbosCtx, s.simpleWorkflow) + RegisterWorkflow(dbosCtx, s.simpleWorkflowValue) // interface method workflow workflowIface := TestWorkflowInterface(&workflowImplementation{ field: "example", }) - simpleWfIface := RegisterWorkflow(executor, workflowIface.Execute) + RegisterWorkflow(dbosCtx, workflowIface.Execute) // Generic workflow - wfInt := RegisterWorkflow(executor, Identity[string]) // FIXME make this an int eventually + RegisterWorkflow(dbosCtx, Identity[int]) // Closure with captured state prefix := "hello-" - wfClose := RegisterWorkflow(executor, func(ctx context.Context, in string) (string, error) { + closureWorkflow := func(dbosCtx DBOSContext, in string) (string, error) { return prefix + in, nil - }) + } + RegisterWorkflow(dbosCtx, closureWorkflow) // Anonymous workflow - anonymousWf := RegisterWorkflow(executor, func(ctx context.Context, in string) (string, error) { + anonymousWorkflow := func(dbosCtx DBOSContext, in string) (string, error) { return "anonymous-" + in, nil - }) + } + RegisterWorkflow(dbosCtx, anonymousWorkflow) type testCase struct { name string - workflowFunc func(context.Context, string, ...WorkflowOption) (any, error) + workflowFunc func(DBOSContext, string, ...WorkflowOption) (any, error) input string expectedResult any expectError bool @@ -126,13 +137,13 @@ func TestWorkflowsWrapping(t *testing.T) { tests := []testCase{ { name: "SimpleWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWf(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflow, input, opts...) if err != nil { return nil, err } - result, err := handle.GetResult(ctx) - _, err2 := handle.GetResult(ctx) + result, err := handle.GetResult() + _, err2 := handle.GetResult() if err2 == nil { t.Fatal("Second call to GetResult should return an error") } @@ -148,12 +159,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowError", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWfError(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflowError, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectError: true, @@ -161,12 +172,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStep", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWfWithStep(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflowWithStep, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "from step", @@ -174,12 +185,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowStruct", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWfStruct(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, s.simpleWorkflow, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "echo", @@ -187,12 +198,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "ValueReceiverWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWfValue(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, s.simpleWorkflowValue, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "echo-value", @@ -200,12 +211,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "interfaceMethodWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWfIface(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, workflowIface.Execute, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "echo-example-interface", @@ -213,26 +224,25 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "GenericWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - // For generic workflow, we need to convert string to int for testing - handle, err := wfInt(ctx, "42", opts...) // FIXME for now this returns a string because sys db accepts this + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, Identity, 42, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "42", // input not used in this case - expectedResult: "42", // FIXME make this an int eventually + expectedResult: 42, expectError: false, }, { name: "ClosureWithCapturedState", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := wfClose(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, closureWorkflow, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "world", expectedResult: "hello-world", @@ -240,12 +250,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "AnonymousClosure", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := anonymousWf(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, anonymousWorkflow, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "test", expectedResult: "anonymous-test", @@ -253,12 +263,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStepError", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { - handle, err := simpleWfWithStepError(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflowWithStepError, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectError: true, @@ -268,7 +278,7 @@ func TestWorkflowsWrapping(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := tc.workflowFunc(context.Background(), tc.input, WithWorkflowID(uuid.NewString())) + result, err := tc.workflowFunc(executor, tc.input, WithWorkflowID(uuid.NewString())) if tc.expectError { if err == nil { @@ -289,50 +299,56 @@ func TestWorkflowsWrapping(t *testing.T) { } } -func stepWithinAStep(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, dbos, simpleStep, input) +func stepWithinAStep(ctx context.Context, input ...any) (string, error) { + return simpleStep(ctx, input...) } -func stepWithinAStepWorkflow(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, dbos, stepWithinAStep, input) +func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, stepWithinAStep, input) } // Global counter for retry testing var stepRetryAttemptCount int -func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) { +func stepRetryAlwaysFailsStep(ctx context.Context, input ...any) (string, error) { stepRetryAttemptCount++ return "", fmt.Errorf("always fails - attempt %d", stepRetryAttemptCount) } var stepIdempotencyCounter int -func stepIdempotencyTest(ctx context.Context, input string) (string, error) { +func stepIdempotencyTest(ctx context.Context, input ...any) (string, error) { stepIdempotencyCounter++ - return input, nil + if len(input) > 0 { + if str, ok := input[0].(string); ok { + return str, nil + } + } + return "", nil } -func stepRetryWorkflow(ctx context.Context, input string) (string, error) { - RunAsStep(ctx, dbos, stepIdempotencyTest, input) - return RunAsStep(ctx, dbos, stepRetryAlwaysFailsStep, input, - WithStepMaxRetries(5), - WithBackoffFactor(2.0), - WithBaseInterval(1*time.Millisecond), - WithMaxInterval(10*time.Millisecond)) +func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) { + RunAsStep(dbosCtx, stepIdempotencyTest, 1) + stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{ + MaxRetries: 5, + BaseInterval: 1 * time.Millisecond, + MaxInterval: 10 * time.Millisecond, + }) + + return RunAsStep[string](stepCtx, stepRetryAlwaysFailsStep, input) } +// TODO: step params func TestSteps(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) // Create workflows with executor - stepWithinAStepWf := RegisterWorkflow(executor, stepWithinAStepWorkflow) - stepRetryWf := RegisterWorkflow(executor, stepRetryWorkflow) + RegisterWorkflow(dbosCtx, stepWithinAStepWorkflow) + RegisterWorkflow(dbosCtx, stepRetryWorkflow) t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { - ctx := context.Background() - // Attempt to run a step outside of a workflow context - _, err := RunAsStep(ctx, dbos, simpleStep, "test") + _, err := RunAsStep(dbosCtx, simpleStep, "test") if err == nil { t.Fatal("expected error when running step outside of workflow context, but got none") } @@ -355,11 +371,11 @@ func TestSteps(t *testing.T) { }) t.Run("StepWithinAStepAreJustFunctions", func(t *testing.T) { - handle, err := stepWithinAStepWf(context.Background(), "test") + handle, err := RunAsWorkflow(dbosCtx, stepWithinAStepWorkflow, "test") if err != nil { t.Fatal("failed to run step within a step:", err) } - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatal("failed to get result from step within a step:", err) } @@ -367,7 +383,7 @@ func TestSteps(t *testing.T) { t.Fatalf("expected result 'from step', got '%s'", result) } - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatal("failed to list steps:", err) } @@ -382,12 +398,12 @@ func TestSteps(t *testing.T) { stepIdempotencyCounter = 0 // Execute the workflow - handle, err := stepRetryWf(context.Background(), "test") + handle, err := RunAsWorkflow(dbosCtx, stepRetryWorkflow, "test") if err != nil { t.Fatal("failed to start retry workflow:", err) } - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err == nil { t.Fatal("expected error from failing workflow but got none") } @@ -422,7 +438,7 @@ func TestSteps(t *testing.T) { } // Verify that the failed step was still recorded in the database - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatal("failed to get workflow steps:", err) } @@ -452,11 +468,11 @@ func TestSteps(t *testing.T) { // TODO Check timeouts behaviors for parents and children (e.g. awaited cancelled, etc) func TestChildWorkflow(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) // Create child workflows with executor - childWf := RegisterWorkflow(executor, func(ctx context.Context, i int) (string, error) { - workflowID, err := GetWorkflowID(ctx) + childWf := func(dbosCtx DBOSContext, i int) (string, error) { + workflowID, err := dbosCtx.GetWorkflowID() if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) } @@ -465,16 +481,17 @@ func TestChildWorkflow(t *testing.T) { return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedCurrentID, workflowID) } // XXX right now the steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow - return RunAsStep(ctx, dbos, simpleStep, "") - }) + return RunAsStep(dbosCtx, simpleStep, "") + } + RegisterWorkflow(dbosCtx, childWf) - parentWf := RegisterWorkflow(executor, func(ctx context.Context, i int) (string, error) { - workflowID, err := GetWorkflowID(ctx) + parentWf := func(dbosCtx DBOSContext, i int) (string, error) { + workflowID, err := dbosCtx.GetWorkflowID() if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) } - childHandle, err := childWf(ctx, i) + childHandle, err := RunAsWorkflow(dbosCtx, childWf, i) if err != nil { return "", err } @@ -491,17 +508,18 @@ func TestChildWorkflow(t *testing.T) { if childWorkflowID != expectedChildID { return "", fmt.Errorf("expected childWf ID to be %s, got %s", expectedChildID, childWorkflowID) } - return childHandle.GetResult(ctx) - }) + return childHandle.GetResult() + } + RegisterWorkflow(dbosCtx, parentWf) - grandParentWf := RegisterWorkflow(executor, func(ctx context.Context, _ string) (string, error) { + grandParentWf := func(dbosCtx DBOSContext, _ string) (string, error) { for i := range 3 { - workflowID, err := GetWorkflowID(ctx) + workflowID, err := dbosCtx.GetWorkflowID() if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) } - childHandle, err := parentWf(ctx, i) + childHandle, err := RunAsWorkflow(dbosCtx, parentWf, i) if err != nil { return "", err } @@ -520,7 +538,7 @@ func TestChildWorkflow(t *testing.T) { } // Calling the child a second time should return a polling handle - childHandle, err = parentWf(ctx, i, WithWorkflowID(childHandle.GetWorkflowID())) + childHandle, err = RunAsWorkflow(dbosCtx, parentWf, i, WithWorkflowID(childHandle.GetWorkflowID())) if err != nil { return "", err } @@ -532,14 +550,15 @@ func TestChildWorkflow(t *testing.T) { } return "", nil - }) + } + RegisterWorkflow(dbosCtx, grandParentWf) t.Run("ChildWorkflowIDPattern", func(t *testing.T) { - h, err := grandParentWf(context.Background(), "") + h, err := RunAsWorkflow(dbosCtx, grandParentWf, "") if err != nil { t.Fatalf("failed to execute grand parent workflow: %v", err) } - _, err = h.GetResult(context.Background()) + _, err = h.GetResult() if err != nil { t.Fatalf("failed to get result from grand parent workflow: %v", err) } @@ -548,30 +567,30 @@ func TestChildWorkflow(t *testing.T) { // Idempotency workflows moved to test functions -func idempotencyWorkflow(ctx context.Context, input string) (string, error) { - incrementCounter(ctx, 1) +func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) { + RunAsStep(dbosCtx, incrementCounter, int64(1)) return input, nil } var blockingStepStopEvent *Event -func blockingStep(ctx context.Context, input string) (string, error) { +func blockingStep(ctx context.Context, input ...any) (string, error) { blockingStepStopEvent.Wait() return "", nil } var idempotencyWorkflowWithStepEvent *Event -func idempotencyWorkflowWithStep(ctx context.Context, input string) (int64, error) { - RunAsStep(ctx, dbos, incrementCounter, 1) +func idempotencyWorkflowWithStep(dbosCtx DBOSContext, input string) (int64, error) { + RunAsStep(dbosCtx, incrementCounter, int64(1)) idempotencyWorkflowWithStepEvent.Set() - RunAsStep(ctx, dbos, blockingStep, input) + RunAsStep(dbosCtx, blockingStep, input) return idempotencyCounter, nil } func TestWorkflowIdempotency(t *testing.T) { - executor := setupDBOS(t) - idempotencyWf := RegisterWorkflow(executor, idempotencyWorkflow) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, idempotencyWorkflow) t.Run("WorkflowExecutedOnlyOnce", func(t *testing.T) { idempotencyCounter = 0 @@ -581,21 +600,21 @@ func TestWorkflowIdempotency(t *testing.T) { // Execute the same workflow twice with the same ID // First execution - handle1, err := idempotencyWf(context.Background(), input, WithWorkflowID(workflowID)) + handle1, err := RunAsWorkflow(dbosCtx, idempotencyWorkflow, input, WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to execute workflow first time: %v", err) } - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from first execution: %v", err) } // Second execution with the same workflow ID - handle2, err := idempotencyWf(context.Background(), input, WithWorkflowID(workflowID)) + handle2, err := RunAsWorkflow(dbosCtx, idempotencyWorkflow, input, WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to execute workflow second time: %v", err) } - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from second execution: %v", err) } @@ -620,7 +639,7 @@ func TestWorkflowIdempotency(t *testing.T) { func TestWorkflowRecovery(t *testing.T) { executor := setupDBOS(t) - idempotencyWfWithStep := RegisterWorkflow(executor, idempotencyWorkflowWithStep) + RegisterWorkflow(executor, idempotencyWorkflowWithStep) t.Run("RecoveryResumeWhereItLeftOff", func(t *testing.T) { // Reset the global counter idempotencyCounter = 0 @@ -629,7 +648,7 @@ func TestWorkflowRecovery(t *testing.T) { input := "recovery-test" idempotencyWorkflowWithStepEvent = NewEvent() blockingStepStopEvent = NewEvent() - handle1, err := idempotencyWfWithStep(context.Background(), input) + handle1, err := RunAsWorkflow(executor, idempotencyWorkflowWithStep, input) if err != nil { t.Fatalf("failed to execute workflow first time: %v", err) } @@ -637,7 +656,7 @@ func TestWorkflowRecovery(t *testing.T) { idempotencyWorkflowWithStepEvent.Wait() // Wait for the first step to complete. The second spins forever. // Run recovery for pending workflows with "local" executor - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -666,7 +685,7 @@ func TestWorkflowRecovery(t *testing.T) { } // Using ListWorkflows, retrieve the status of the workflow - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ + workflows, err := executor.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{handle1.GetWorkflowID()}, }) if err != nil { @@ -686,7 +705,7 @@ func TestWorkflowRecovery(t *testing.T) { // unlock the workflow & wait for result blockingStepStopEvent.Set() // This will allow the blocking step to complete - result, err := recoveredHandle.GetResult(context.Background()) + result, err := recoveredHandle.GetResult() if err != nil { t.Fatalf("failed to get result from recovered handle: %v", err) } @@ -703,7 +722,7 @@ var ( recoveryCount int64 ) -func deadLetterQueueWorkflow(ctx context.Context, input string) (int, error) { +func deadLetterQueueWorkflow(ctx DBOSContext, input string) (int, error) { recoveryCount++ fmt.Printf("Dead letter queue workflow started, recovery count: %d\n", recoveryCount) deadLetterQueueStartEvent.Set() @@ -711,15 +730,15 @@ func deadLetterQueueWorkflow(ctx context.Context, input string) (int, error) { return 0, nil } -func infiniteDeadLetterQueueWorkflow(ctx context.Context, input string) (int, error) { +func infiniteDeadLetterQueueWorkflow(ctx DBOSContext, input string) (int, error) { deadLetterQueueStartEvent.Set() deadLetterQueueEvent.Wait() return 0, nil } func TestWorkflowDeadLetterQueue(t *testing.T) { executor := setupDBOS(t) - deadLetterQueueWf := RegisterWorkflow(executor, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) - infiniteDeadLetterQueueWf := RegisterWorkflow(executor, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries + RegisterWorkflow(executor, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) + RegisterWorkflow(executor, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries t.Run("DeadLetterQueueBehavior", func(t *testing.T) { deadLetterQueueEvent = NewEvent() @@ -728,7 +747,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Start a workflow that blocks forever wfID := uuid.NewString() - handle, err := deadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + handle, err := RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start dead letter queue workflow: %v", err) } @@ -737,7 +756,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Attempt to recover the blocked workflow the maximum number of times for i := range maxRecoveryAttempts { - _, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + _, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) } @@ -750,7 +769,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { } // Verify an additional attempt throws a DLQ error and puts the workflow in the DLQ status - _, err = recoverPendingWorkflows(context.Background(), []string{"local"}) + _, err = recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err == nil { t.Fatal("expected dead letter queue error but got none") } @@ -773,7 +792,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { } // Verify that attempting to start a workflow with the same ID throws a DLQ error - _, err = deadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + _, err = RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err == nil { t.Fatal("expected dead letter queue error when restarting workflow with same ID but got none") } @@ -793,7 +812,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { resumedHandle, err := ... // Recover pending workflows again - should work without error - _, err = recoverPendingWorkflows(context.Background(), []string{"local"}) + _, err = recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows after resume: %v", err) } @@ -827,7 +846,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Verify that retries of a completed workflow do not raise the DLQ exception for i := 0; i < maxRecoveryAttempts*2; i++ { - _, err = deadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + _, err = RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("unexpected error when retrying completed workflow: %v", err) } @@ -842,7 +861,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Verify that a workflow with MaxRetries=0 (infinite retries) is retried infinitely wfID := uuid.NewString() - handle, err := infiniteDeadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + handle, err := RunAsWorkflow(executor, infiniteDeadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start infinite dead letter queue workflow: %v", err) } @@ -852,7 +871,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Attempt to recover the blocked workflow many times (should never fail) handles := []WorkflowHandle[any]{} for i := range _DEFAULT_MAX_RECOVERY_ATTEMPTS * 2 { - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) } @@ -864,7 +883,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Complete the workflow deadLetterQueueEvent.Set() - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from infinite dead letter queue workflow: %v", err) } @@ -874,7 +893,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Wait for all handles to complete for i, h := range handles { - result, err := h.GetResult(context.Background()) + result, err := h.GetResult() if err != nil { t.Fatalf("failed to get result from handle %d: %v", i, err) } @@ -892,17 +911,13 @@ var ( func TestScheduledWorkflows(t *testing.T) { executor := setupDBOS(t) - _ = RegisterWorkflow(executor, func(ctx context.Context, scheduledTime time.Time) (string, error) { + RegisterWorkflow(executor, func(ctx DBOSContext, scheduledTime time.Time) (string, error) { startTime := time.Now() counter++ if counter == 10 { - return "", fmt.Errorf("counter reached 100, stopping workflow") - } - select { - case counter1Ch <- startTime: - default: + return "", fmt.Errorf("counter reached 10, stopping workflow") } - return fmt.Sprintf("Scheduled workflow scheduled at time %v and executed at time %v", scheduledTime, startTime), nil + return fmt.Sprintf("Scheduled workflow executed at %v", startTime), nil }, WithSchedule("* * * * * *")) // Every second // Helper function to collect execution times @@ -954,7 +969,7 @@ func TestScheduledWorkflows(t *testing.T) { // Stop the workflowScheduler and check if it stops executing currentCounter := counter - executor.GetWorkflowScheduler().Stop() + executor.(*dbosContext).getWorkflowScheduler().Stop() time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions if counter >= currentCounter+2 { t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2) @@ -976,39 +991,43 @@ type sendWorkflowInput struct { Topic string } -func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) +func sendWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { + err := Send(ctx, WorkflowSendInput[string]{ + DestinationID: input.DestinationID, + Topic: input.Topic, + Message: "message1", + }) if err != nil { return "", err } - err = Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message2"}) + err = Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message2"}) if err != nil { return "", err } - err = Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message3"}) + err = Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message3"}) if err != nil { return "", err } return "", nil } -func receiveWorkflow(ctx context.Context, topic string) (string, error) { - msg1, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) +func receiveWorkflow(ctx DBOSContext, topic string) (string, error) { + msg1, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } - msg2, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + msg2, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } - msg3, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) + msg3, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } return msg1 + "-" + msg2 + "-" + msg3, nil } -func receiveWorkflowCoordinated(ctx context.Context, input struct { +func receiveWorkflowCoordinated(ctx DBOSContext, input struct { Topic string i int }) (string, error) { @@ -1020,25 +1039,25 @@ func receiveWorkflowCoordinated(ctx context.Context, input struct { concurrentRecvStartEvent.Wait() // Do a single Recv call with timeout - msg, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: input.Topic, Timeout: 3 * time.Second}) + msg, err := Recv[string](ctx, WorkflowRecvInput{Topic: input.Topic, Timeout: 3 * time.Second}) if err != nil { return "", err } return msg, nil } -func sendStructWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { +func sendStructWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { testStruct := sendRecvType{Value: "test-struct-value"} - err := Send(ctx, dbos, WorkflowSendInput[sendRecvType]{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) + err := Send(ctx, WorkflowSendInput[sendRecvType]{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) return "", err } -func receiveStructWorkflow(ctx context.Context, topic string) (sendRecvType, error) { - return Recv[sendRecvType](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) +func receiveStructWorkflow(ctx DBOSContext, topic string) (sendRecvType, error) { + return Recv[sendRecvType](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) } -func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, dbos, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) +func sendIdempotencyWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { + err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) if err != nil { return "", err } @@ -1046,8 +1065,8 @@ func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (stri return "idempotent-send-completed", nil } -func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, error) { - msg, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) +func receiveIdempotencyWorkflow(ctx DBOSContext, topic string) (string, error) { + msg, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err } @@ -1056,20 +1075,20 @@ func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, erro return msg, nil } -func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, dbos, WorkflowSendInput[string]{ - DestinationID: input.DestinationID, - Topic: input.Topic, - Message: "message-from-step", - }) - if err != nil { - return "", err +func stepThatCallsSend(ctx context.Context, input ...any) (string, error) { + if len(input) == 0 { + return "", fmt.Errorf("expected sendWorkflowInput") + } + _, ok := input[0].(sendWorkflowInput) + if !ok { + return "", fmt.Errorf("expected sendWorkflowInput, got %T", input[0]) } - return "send-completed", nil + // Note: Send cannot be called from within steps, this should fail + return "", fmt.Errorf("Send cannot be called from within a step") } -func workflowThatCallsSendInStep(ctx context.Context, input sendWorkflowInput) (string, error) { - return RunAsStep(ctx, dbos, stepThatCallsSend, input) +func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) { + return RunAsStep(ctx, stepThatCallsSend, input) } type sendRecvType struct { @@ -1080,37 +1099,37 @@ func TestSendRecv(t *testing.T) { executor := setupDBOS(t) // Register all send/recv workflows with executor - sendWf := RegisterWorkflow(executor, sendWorkflow) - receiveWf := RegisterWorkflow(executor, receiveWorkflow) - receiveWfCoordinated := RegisterWorkflow(executor, receiveWorkflowCoordinated) - sendStructWf := RegisterWorkflow(executor, sendStructWorkflow) - receiveStructWf := RegisterWorkflow(executor, receiveStructWorkflow) - sendIdempotencyWf := RegisterWorkflow(executor, sendIdempotencyWorkflow) - recvIdempotencyWf := RegisterWorkflow(executor, receiveIdempotencyWorkflow) - sendWithinStepWf := RegisterWorkflow(executor, workflowThatCallsSendInStep) + RegisterWorkflow(executor, sendWorkflow) + RegisterWorkflow(executor, receiveWorkflow) + RegisterWorkflow(executor, receiveWorkflowCoordinated) + RegisterWorkflow(executor, sendStructWorkflow) + RegisterWorkflow(executor, receiveStructWorkflow) + RegisterWorkflow(executor, sendIdempotencyWorkflow) + RegisterWorkflow(executor, receiveIdempotencyWorkflow) + RegisterWorkflow(executor, workflowThatCallsSendInStep) t.Run("SendRecvSuccess", func(t *testing.T) { // Start the receive workflow - receiveHandle, err := receiveWf(context.Background(), "test-topic") + receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "test-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send a message to the receive workflow - handle, err := sendWf(context.Background(), sendWorkflowInput{ + handle, err := RunAsWorkflow(executor, sendWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "test-topic", }) if err != nil { t.Fatalf("failed to send message: %v", err) } - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err != nil { t.Fatalf("failed to get result from send workflow: %v", err) } start := time.Now() - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1125,13 +1144,13 @@ func TestSendRecv(t *testing.T) { t.Run("SendRecvCustomStruct", func(t *testing.T) { // Start the receive workflow - receiveHandle, err := receiveStructWf(context.Background(), "struct-topic") + receiveHandle, err := RunAsWorkflow(executor, receiveStructWorkflow, "struct-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send the struct to the receive workflow - sendHandle, err := sendStructWf(context.Background(), sendWorkflowInput{ + sendHandle, err := RunAsWorkflow(executor, sendStructWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "struct-topic", }) @@ -1139,13 +1158,13 @@ func TestSendRecv(t *testing.T) { t.Fatalf("failed to send struct: %v", err) } - _, err = sendHandle.GetResult(context.Background()) + _, err = sendHandle.GetResult() if err != nil { t.Fatalf("failed to get result from send workflow: %v", err) } // Get the result from receive workflow - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1161,7 +1180,7 @@ func TestSendRecv(t *testing.T) { destUUID := uuid.NewString() // Send to non-existent UUID should fail - handle, err := sendWf(context.Background(), sendWorkflowInput{ + handle, err := RunAsWorkflow(executor, sendWorkflow, sendWorkflowInput{ DestinationID: destUUID, Topic: "testtopic", }) @@ -1169,7 +1188,7 @@ func TestSendRecv(t *testing.T) { t.Fatalf("failed to start send workflow: %v", err) } - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err == nil { t.Fatal("expected error when sending to non-existent UUID but got none") } @@ -1191,11 +1210,11 @@ func TestSendRecv(t *testing.T) { t.Run("RecvTimeout", func(t *testing.T) { // Create a receive workflow that tries to receive a message but no send happens - receiveHandle, err := receiveWf(context.Background(), "timeout-test-topic") + receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "timeout-test-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if result != "--" { t.Fatalf("expected -- result on timeout, got '%s'", result) } @@ -1205,10 +1224,8 @@ func TestSendRecv(t *testing.T) { }) t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) { - ctx := context.Background() - // Attempt to run Recv outside of a workflow context - _, err := Recv[string](ctx, dbos, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) + _, err := Recv[string](executor, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) if err == nil { t.Fatal("expected error when running Recv outside of workflow context, but got none") } @@ -1232,15 +1249,14 @@ func TestSendRecv(t *testing.T) { t.Run("SendOutsideWorkflow", func(t *testing.T) { // Start a receive workflow to have a valid destination - receiveHandle, err := receiveWf(context.Background(), "outside-workflow-topic") + receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "outside-workflow-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send messages from outside a workflow context (should work now) - ctx := context.Background() for i := range 3 { - err = Send(ctx, dbos, WorkflowSendInput[string]{ + err = Send(executor, WorkflowSendInput[string]{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "outside-workflow-topic", Message: fmt.Sprintf("message%d", i+1), @@ -1251,7 +1267,7 @@ func TestSendRecv(t *testing.T) { } // Verify the receive workflow gets all messages - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1261,13 +1277,13 @@ func TestSendRecv(t *testing.T) { }) t.Run("SendRecvIdempotency", func(t *testing.T) { // Start the receive workflow and wait for it to be ready - receiveHandle, err := recvIdempotencyWf(context.Background(), "idempotency-topic") + receiveHandle, err := RunAsWorkflow(executor, receiveIdempotencyWorkflow, "idempotency-topic") if err != nil { t.Fatalf("failed to start receive idempotency workflow: %v", err) } // Send the message to the receive workflow - sendHandle, err := sendIdempotencyWf(context.Background(), sendWorkflowInput{ + sendHandle, err := RunAsWorkflow(executor, sendIdempotencyWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "idempotency-topic", }) @@ -1279,21 +1295,21 @@ func TestSendRecv(t *testing.T) { receiveIdempotencyStartEvent.Wait() // Attempt recovering both workflows. There should be only 2 steps recorded after recovery. - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } if len(recoveredHandles) != 2 { t.Fatalf("expected 2 recovered handles, got %d", len(recoveredHandles)) } - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) + steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) if err != nil { - t.Fatalf("failed to get steps for send idempotency workflow: %v", err) + t.Fatalf("failed to get workflow steps: %v", err) } if len(steps) != 1 { t.Fatalf("expected 1 step in send idempotency workflow, got %d", len(steps)) } - steps, err = dbos.systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) + steps, err = executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for receive idempotency workflow: %v", err) } @@ -1303,7 +1319,7 @@ func TestSendRecv(t *testing.T) { // Unblock the workflows to complete receiveIdempotencyStopEvent.Set() - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive idempotency workflow: %v", err) } @@ -1311,7 +1327,7 @@ func TestSendRecv(t *testing.T) { t.Fatalf("expected result to be 'm1', got '%s'", result) } sendIdempotencyEvent.Set() - result, err = sendHandle.GetResult(context.Background()) + result, err = sendHandle.GetResult() if err != nil { t.Fatalf("failed to get result from send idempotency workflow: %v", err) } @@ -1322,13 +1338,13 @@ func TestSendRecv(t *testing.T) { t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) { // Start a receive workflow to have a valid destination - receiveHandle, err := receiveWf(context.Background(), "send-within-step-topic") + receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "send-within-step-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Execute the workflow that tries to call Send within a step - handle, err := sendWithinStepWf(context.Background(), sendWorkflowInput{ + handle, err := RunAsWorkflow(executor, workflowThatCallsSendInStep, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "send-within-step-topic", }) @@ -1337,7 +1353,7 @@ func TestSendRecv(t *testing.T) { } // Expect the workflow to fail with the specific error - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err == nil { t.Fatal("expected error when calling Send within a step, but got none") } @@ -1359,7 +1375,7 @@ func TestSendRecv(t *testing.T) { } // Wait for the receive workflow to time out - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1382,7 +1398,7 @@ func TestSendRecv(t *testing.T) { // Start all receivers - they will signal when ready and wait for coordination for i := range numReceivers { concurrentRecvReadyEvents[i] = NewEvent() - receiveHandle, err := receiveWfCoordinated(context.Background(), struct { + receiveHandle, err := RunAsWorkflow(executor, receiveWorkflowCoordinated, struct { Topic string i int }{ @@ -1409,7 +1425,7 @@ func TestSendRecv(t *testing.T) { for i := range numReceivers { go func(index int) { defer wg.Done() - result, err := receiverHandles[index].GetResult(context.Background()) + result, err := receiverHandles[index].GetResult() if err != nil { errors <- err } else { @@ -1466,16 +1482,16 @@ type setEventWorkflowInput struct { Message string } -func setEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) +func setEventWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { + err := SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } return "event-set", nil } -func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - result, err := GetEvent[string](ctx, dbos, WorkflowGetEventInput{ +func getEventWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { + result, err := GetEvent[string](ctx, WorkflowGetEventInput{ TargetWorkflowID: input.Key, // Reusing Key field as target workflow ID Key: input.Message, // Reusing Message field as event key Timeout: 3 * time.Second, @@ -1486,9 +1502,9 @@ func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, return result, nil } -func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { +func setTwoEventsWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { // Set the first event - err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: "event1", Message: "first-event-message"}) + err := SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: "event1", Message: "first-event-message"}) if err != nil { return "", err } @@ -1497,7 +1513,7 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str setSecondEventSignal.Wait() // Set the second event - err = SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: "event2", Message: "second-event-message"}) + err = SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: "event2", Message: "second-event-message"}) if err != nil { return "", err } @@ -1505,8 +1521,8 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str return "two-events-set", nil } -func setEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) +func setEventIdempotencyWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { + err := SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } @@ -1514,8 +1530,8 @@ func setEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInpu return "idempotent-set-completed", nil } -func getEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - result, err := GetEvent[string](ctx, dbos, WorkflowGetEventInput{ +func getEventIdempotencyWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { + result, err := GetEvent[string](ctx, WorkflowGetEventInput{ TargetWorkflowID: input.Key, Key: input.Message, Timeout: 3 * time.Second, @@ -1532,18 +1548,18 @@ func TestSetGetEvent(t *testing.T) { executor := setupDBOS(t) // Register all set/get event workflows with executor - setEventWf := RegisterWorkflow(executor, setEventWorkflow) - getEventWf := RegisterWorkflow(executor, getEventWorkflow) - setTwoEventsWf := RegisterWorkflow(executor, setTwoEventsWorkflow) - setEventIdempotencyWf := RegisterWorkflow(executor, setEventIdempotencyWorkflow) - getEventIdempotencyWf := RegisterWorkflow(executor, getEventIdempotencyWorkflow) + RegisterWorkflow(executor, setEventWorkflow) + RegisterWorkflow(executor, getEventWorkflow) + RegisterWorkflow(executor, setTwoEventsWorkflow) + RegisterWorkflow(executor, setEventIdempotencyWorkflow) + RegisterWorkflow(executor, getEventIdempotencyWorkflow) t.Run("SetGetEventFromWorkflow", func(t *testing.T) { // Clear the signal event before starting setSecondEventSignal.Clear() // Start the workflow that sets two events - setHandle, err := setTwoEventsWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(executor, setTwoEventsWorkflow, setEventWorkflowInput{ Key: "test-workflow", Message: "unused", }) @@ -1552,7 +1568,7 @@ func TestSetGetEvent(t *testing.T) { } // Start a workflow to get the first event - getFirstEventHandle, err := getEventWf(context.Background(), setEventWorkflowInput{ + getFirstEventHandle, err := RunAsWorkflow(executor, getEventWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), // Target workflow ID Message: "event1", // Event key }) @@ -1561,7 +1577,7 @@ func TestSetGetEvent(t *testing.T) { } // Verify we can get the first event - firstMessage, err := getFirstEventHandle.GetResult(context.Background()) + firstMessage, err := getFirstEventHandle.GetResult() if err != nil { t.Fatalf("failed to get result from first event workflow: %v", err) } @@ -1573,7 +1589,7 @@ func TestSetGetEvent(t *testing.T) { setSecondEventSignal.Set() // Start a workflow to get the second event - getSecondEventHandle, err := getEventWf(context.Background(), setEventWorkflowInput{ + getSecondEventHandle, err := RunAsWorkflow(executor, getEventWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), // Target workflow ID Message: "event2", // Event key }) @@ -1582,7 +1598,7 @@ func TestSetGetEvent(t *testing.T) { } // Verify we can get the second event - secondMessage, err := getSecondEventHandle.GetResult(context.Background()) + secondMessage, err := getSecondEventHandle.GetResult() if err != nil { t.Fatalf("failed to get result from second event workflow: %v", err) } @@ -1591,7 +1607,7 @@ func TestSetGetEvent(t *testing.T) { } // Wait for the workflow to complete - result, err := setHandle.GetResult(context.Background()) + result, err := setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set two events workflow: %v", err) } @@ -1602,7 +1618,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventFromOutsideWorkflow", func(t *testing.T) { // Start a workflow that sets an event - setHandle, err := setEventWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(executor, setEventWorkflow, setEventWorkflowInput{ Key: "test-key", Message: "test-message", }) @@ -1611,13 +1627,13 @@ func TestSetGetEvent(t *testing.T) { } // Wait for the event to be set - _, err = setHandle.GetResult(context.Background()) + _, err = setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set event workflow: %v", err) } // Start a workflow that gets the event from outside the original workflow - message, err := GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ + message, err := GetEvent[string](executor, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "test-key", Timeout: 3 * time.Second, @@ -1633,7 +1649,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventTimeout", func(t *testing.T) { // Try to get an event from a non-existent workflow nonExistentID := uuid.NewString() - message, err := GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ + message, err := GetEvent[string](executor, WorkflowGetEventInput{ TargetWorkflowID: nonExistentID, Key: "test-key", Timeout: 3 * time.Second, @@ -1646,18 +1662,18 @@ func TestSetGetEvent(t *testing.T) { } // Try to get an event from an existing workflow but with a key that doesn't exist - setHandle, err := setEventWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(executor, setEventWorkflow, setEventWorkflowInput{ Key: "test-key", Message: "test-message", }) if err != nil { t.Fatal("failed to set event:", err) } - _, err = setHandle.GetResult(context.Background()) + _, err = setHandle.GetResult() if err != nil { t.Fatal("failed to get result from set event workflow:", err) } - message, err = GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ + message, err = GetEvent[string](executor, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "non-existent-key", Timeout: 3 * time.Second, @@ -1671,10 +1687,8 @@ func TestSetGetEvent(t *testing.T) { }) t.Run("SetGetEventMustRunInsideWorkflows", func(t *testing.T) { - ctx := context.Background() - // Attempt to run SetEvent outside of a workflow context - err := SetEvent(ctx, dbos, WorkflowSetEventInput[string]{Key: "test-key", Message: "test-message"}) + err := SetEvent(executor, WorkflowSetEventInputGeneric[string]{Key: "test-key", Message: "test-message"}) if err == nil { t.Fatal("expected error when running SetEvent outside of workflow context, but got none") } @@ -1698,7 +1712,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("SetGetEventIdempotency", func(t *testing.T) { // Start the set event workflow - setHandle, err := setEventIdempotencyWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(executor, setEventIdempotencyWorkflow, setEventWorkflowInput{ Key: "idempotency-key", Message: "idempotency-message", }) @@ -1707,7 +1721,7 @@ func TestSetGetEvent(t *testing.T) { } // Start the get event workflow - getHandle, err := getEventIdempotencyWf(context.Background(), setEventWorkflowInput{ + getHandle, err := RunAsWorkflow(executor, getEventIdempotencyWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), Message: "idempotency-key", }) @@ -1720,7 +1734,7 @@ func TestSetGetEvent(t *testing.T) { getEventStartIdempotencyEvent.Clear() // Attempt recovering both workflows. Each should have exactly 1 step. - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -1731,7 +1745,7 @@ func TestSetGetEvent(t *testing.T) { getEventStartIdempotencyEvent.Wait() // Verify step counts - setSteps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) + setSteps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for set event idempotency workflow: %v", err) } @@ -1739,7 +1753,7 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("expected 1 step in set event idempotency workflow, got %d", len(setSteps)) } - getSteps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), getHandle.GetWorkflowID()) + getSteps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), getHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for get event idempotency workflow: %v", err) } @@ -1751,7 +1765,7 @@ func TestSetGetEvent(t *testing.T) { setEventIdempotencyEvent.Set() getEventStopIdempotencyEvent.Set() - setResult, err := setHandle.GetResult(context.Background()) + setResult, err := setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set event idempotency workflow: %v", err) } @@ -1759,7 +1773,7 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("expected result to be 'idempotent-set-completed', got '%s'", setResult) } - getResult, err := getHandle.GetResult(context.Background()) + getResult, err := getHandle.GetResult() if err != nil { t.Fatalf("failed to get result from get event idempotency workflow: %v", err) } @@ -1770,7 +1784,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("ConcurrentGetEvent", func(t *testing.T) { // Set event - setHandle, err := setEventWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(executor, setEventWorkflow, setEventWorkflowInput{ Key: "concurrent-event-key", Message: "concurrent-event-message", }) @@ -1779,7 +1793,7 @@ func TestSetGetEvent(t *testing.T) { } // Wait for the set event workflow to complete - _, err = setHandle.GetResult(context.Background()) + _, err = setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set event workflow: %v", err) } @@ -1791,7 +1805,7 @@ func TestSetGetEvent(t *testing.T) { for range numGoroutines { go func() { defer wg.Done() - res, err := GetEvent[string](context.Background(), dbos, WorkflowGetEventInput{ + res, err := GetEvent[string](executor, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "concurrent-event-key", Timeout: 10 * time.Second, @@ -1821,8 +1835,20 @@ var ( sleepStopEvent *Event ) -func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Duration, error) { - result, err := Sleep(ctx, dbos, duration) +func sleepStep(ctx context.Context, input ...any) (time.Duration, error) { + if len(input) == 0 { + return 0, fmt.Errorf("expected duration") + } + duration, ok := input[0].(time.Duration) + if !ok { + return 0, fmt.Errorf("expected time.Duration, got %T", input[0]) + } + time.Sleep(duration) + return duration, nil +} + +func sleepRecoveryWorkflow(ctx DBOSContext, duration time.Duration) (time.Duration, error) { + result, err := RunAsStep(ctx, sleepStep, duration) if err != nil { return 0, err } @@ -1834,7 +1860,7 @@ func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Du func TestSleep(t *testing.T) { executor := setupDBOS(t) - sleepRecoveryWf := RegisterWorkflow(executor, sleepRecoveryWorkflow) + RegisterWorkflow(executor, sleepRecoveryWorkflow) t.Run("SleepDurableRecovery", func(t *testing.T) { sleepStartEvent = NewEvent() @@ -1843,7 +1869,7 @@ func TestSleep(t *testing.T) { // Start a workflow that sleeps for 2 seconds then blocks sleepDuration := 2 * time.Second - handle, err := sleepRecoveryWf(context.Background(), sleepDuration) + handle, err := RunAsWorkflow(executor, sleepRecoveryWorkflow, sleepDuration) if err != nil { t.Fatalf("failed to start sleep recovery workflow: %v", err) } @@ -1853,7 +1879,7 @@ func TestSleep(t *testing.T) { // Run the workflow again and check the return time was less than the durable sleep startTime := time.Now() - _, err = sleepRecoveryWf(context.Background(), sleepDuration, WithWorkflowID(handle.GetWorkflowID())) + _, err = RunAsWorkflow(executor, sleepRecoveryWorkflow, sleepDuration, WithWorkflowID(handle.GetWorkflowID())) if err != nil { t.Fatalf("failed to start second sleep recovery workflow: %v", err) } @@ -1866,7 +1892,7 @@ func TestSleep(t *testing.T) { } // Verify the sleep step was recorded correctly - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -1876,8 +1902,8 @@ func TestSleep(t *testing.T) { } step := steps[0] - if step.FunctionName != "DBOS.sleep" { - t.Fatalf("expected step name to be 'DBOS.sleep', got '%s'", step.FunctionName) + if step.FunctionName != "dbos.sleepStep" { + t.Fatalf("expected step name to be 'dbos.sleepStep', got '%s'", step.FunctionName) } if step.Error != nil { @@ -1888,10 +1914,8 @@ func TestSleep(t *testing.T) { }) t.Run("SleepCannotBeCalledOutsideWorkflow", func(t *testing.T) { - ctx := context.Background() - // Attempt to call Sleep outside of a workflow context - _, err := Sleep(ctx, dbos, 1*time.Second) + _, err := RunAsStep(executor, sleepStep, 1*time.Second) if err == nil { t.Fatal("expected error when calling Sleep outside of workflow context, but got none") } From 6c332ce65d116d9b4035a2138431b67360ab1efe Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 12:51:29 -0700 Subject: [PATCH 11/30] shutdown logs + check if launched --- dbos/dbos.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dbos/dbos.go b/dbos/dbos.go index d0f7847a..f01b63d5 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -276,11 +276,18 @@ func (c *dbosContext) Launch() error { } func (c *dbosContext) Shutdown() { + if !c.launched { + logger.Warn("DBOS is not launched, nothing to shutdown") + return + } + // Wait for all workflows to finish + getLogger().Info("Waiting for all workflows to finish") c.workflowsWg.Wait() // Cancel the context to stop the queue runner if c.queueRunnerCancelFunc != nil { + getLogger().Info("Stopping queue runner") c.queueRunnerCancelFunc() // Wait for queue runner to finish <-c.queueRunnerDone @@ -288,6 +295,7 @@ func (c *dbosContext) Shutdown() { } if c.workflowScheduler != nil { + getLogger().Info("Stopping workflow scheduler") ctx := c.workflowScheduler.Stop() // Wait for all running jobs to complete with 5-second timeout timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -302,11 +310,13 @@ func (c *dbosContext) Shutdown() { } if c.systemDB != nil { + getLogger().Info("Shutting down system database") c.systemDB.Shutdown() c.systemDB = nil } if c.adminServer != nil { + getLogger().Info("Shutting down admin server") err := c.adminServer.Shutdown() if err != nil { getLogger().Error("Failed to shutdown admin server", "error", err) From d0f4272588c6ecdfccd6e955421edae1b023c8a4 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 12:51:51 -0700 Subject: [PATCH 12/30] cleanup --- dbos/recovery.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/dbos/recovery.go b/dbos/recovery.go index e5497d6e..bfa74d7a 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -24,7 +24,6 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work } } - // fmt.Println("Recovering workflow:", workflow.ID, "Name:", workflow.Name, "Input:", workflow.Input, "QueueName:", workflow.QueueName) if workflow.QueueName != "" { cleared, err := dbosCtx.systemDB.ClearQueueAssignment(dbosCtx.ctx, workflow.ID) if err != nil { @@ -47,14 +46,6 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work opts := []WorkflowOption{ WithWorkflowID(workflow.ID), } - // XXX we'll figure out the exact timeout/deadline settings later - if workflow.Timeout != 0 { - opts = append(opts, WithTimeout(workflow.Timeout)) - } - if !workflow.Deadline.IsZero() { - opts = append(opts, WithDeadline(workflow.Deadline)) - } - // Create a workflow context from the executor context handle, err := registeredWorkflow.wrappedFunction(dbosCtx, workflow.Input, opts...) if err != nil { From c996d7a79fb2787f356a78c2183053832d3cf99c Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 12:56:32 -0700 Subject: [PATCH 13/30] fix step output handling + retrieval of registration-time options + call interface method directly during recovery / scheduled wfs --- dbos/workflow.go | 71 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index 8989ced7..9b25e4a6 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -144,6 +144,7 @@ type workflowPollingHandle[R any] struct { } func (h *workflowPollingHandle[R]) GetResult() (R, error) { + // FIXME this should use a context available to the user, so they can cancel it instead of infinite waiting ctx := context.Background() result, err := h.dbosContext.(*dbosContext).systemDB.AwaitWorkflowResult(h.dbosContext.(*dbosContext).ctx, h.workflowID) if result != nil { @@ -229,7 +230,7 @@ func registerWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkfl } } -func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkflowFunc, cronSchedule string, maxRetries int) { +func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn WorkflowFunc, cronSchedule string) { // Skip if we don't have a concrete dbosContext c, ok := dbosCtx.(*dbosContext) if !ok { @@ -251,7 +252,12 @@ func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn Wrap scheduledTime = entry.Next } wfID := fmt.Sprintf("sched-%s-%s", workflowName, scheduledTime) // XXX we can rethink the format - fn(c, scheduledTime, WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) + opts := []WorkflowOption{ + WithWorkflowID(wfID), + WithQueue(_DBOS_INTERNAL_QUEUE_NAME), + withWorkflowName(workflowName), + } + dbosCtx.RunAsWorkflow(dbosCtx, fn, scheduledTime, opts...) }) if err != nil { panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) @@ -262,7 +268,6 @@ func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn Wrap type workflowRegistrationParams struct { cronSchedule string maxRetries int - workflowName string } type workflowRegistrationOption func(*workflowRegistrationParams) @@ -283,12 +288,6 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } -func WithWorkflowName(name string) workflowRegistrationOption { - return func(p *workflowRegistrationParams) { - p.workflowName = name - } -} - // RegisterWorkflow registers the provided function as a durable workflow with the provided DBOSContext workflow registry // If the workflow is a scheduled workflow (determined by the presence of a cron schedule), it will also register a cron job to execute it // RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding @@ -299,18 +298,19 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ panic("dbosCtx cannot be nil") } + if fn == nil { + panic("workflow function cannot be nil") + } + registrationParams := workflowRegistrationParams{ - maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, - workflowName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), + maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, } for _, opt := range opts { opt(®istrationParams) } - if fn == nil { - panic("workflow function cannot be nil") - } + fqn := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // Registry the input/output types for gob encoding var p P @@ -319,27 +319,31 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ gob.Register(r) // Register a type-erased version of the durable workflow for recovery + typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { + return fn(ctx, input.(P)) + }) + typeErasedWrapper := WrappedWorkflowFunc(func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { - return nil, newWorkflowUnexpectedInputType(registrationParams.workflowName, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) + return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } - opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) - handle, err := RunAsWorkflow(ctx, fn, typedInput, opts...) + opts = append(opts, withWorkflowName(fqn)) // Append the name so dbosCtx.RunAsWorkflow can look it up from the registry to apply registration-time options + handle, err := dbosCtx.RunAsWorkflow(ctx, typedErasedWorkflow, typedInput, opts...) if err != nil { return nil, err } return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), dbosContext: ctx}, nil }) - registerWorkflow(dbosCtx, registrationParams.workflowName, typeErasedWrapper, registrationParams.maxRetries) + registerWorkflow(dbosCtx, fqn, typeErasedWrapper, registrationParams.maxRetries) // If this is a scheduled workflow, register a cron job if registrationParams.cronSchedule != "" { if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { panic(fmt.Sprintf("scheduled workflow function must accept a time.Time as input, got %T", p)) } - registerScheduledWorkflow(dbosCtx, registrationParams.workflowName, typeErasedWrapper, registrationParams.cronSchedule, registrationParams.maxRetries) + registerScheduledWorkflow(dbosCtx, fqn, typedErasedWorkflow, registrationParams.cronSchedule) } } @@ -402,7 +406,20 @@ func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { } } +func withWorkflowName(name string) WorkflowOption { + return func(p *workflowParams) { + p.workflowName = name + } +} + func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { + if dbosCtx == nil { + return nil, fmt.Errorf("dbosCtx cannot be nil") + } + + // Add the fn name to the options so we can communicate it with DBOSContext.RunAsWorkflow + opts = append(opts, withWorkflowName(runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name())) + typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { return fn(ctx, input.(P)) }) @@ -468,9 +485,19 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input opt(¶ms) } + // Lookup the registry for registration-time options + registeredWorkflow, exists := dbosCtx.(*dbosContext).workflowRegistry[params.workflowName] + if !exists { + return nil, newNonExistentWorkflowError(params.workflowName) + } + if registeredWorkflow.maxRetries > 0 { + params.maxRetries = registeredWorkflow.maxRetries + } + // Check if we are within a workflow (and thus a child workflow) parentWorkflowState, ok := dbosCtx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil + // TODO Check if cancelled // Generate an ID for the workflow if not provided @@ -674,7 +701,11 @@ func RunAsStep[R any](dbosCtx DBOSContext, fn GenericStepFunc[R], input ...any) // Call the executor method result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input...) if err != nil { - return *new(R), err + // In case the errors comes from the DBOS step logic, the result will be nil and we must handle it + if result == nil { + return *new(R), err + } + return result.(R), err } // Type-check and cast the result From 4a3fcd21f273788e628e6be1374d2063fd977848 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 12:56:53 -0700 Subject: [PATCH 14/30] fix qrunner context for now --- dbos/queue.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbos/queue.go b/dbos/queue.go index 070cbd8a..6960a271 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -118,7 +118,7 @@ func queueRunner(executor *dbosContext) { for queueName, queue := range workflowQueueRegistry { getLogger().Debug("Processing queue", "queue_name", queueName) // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := executor.systemDB.DequeueWorkflows(executor.ctx, queue, executor.executorID, executor.applicationVersion) + dequeuedWorkflows, err := executor.systemDB.DequeueWorkflows(executor.queueRunnerCtx, queue, executor.executorID, executor.applicationVersion) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { @@ -183,7 +183,7 @@ func queueRunner(executor *dbosContext) { // Sleep with jittered interval, but allow early exit on context cancellation select { - case <-executor.ctx.Done(): + case <-executor.queueRunnerCtx.Done(): getLogger().Info("Queue runner stopping due to context cancellation") return case <-time.After(sleepDuration): From fb97d3d7b18fc0e49f4067b419fefdf3611b7e12 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 13:18:51 -0700 Subject: [PATCH 15/30] update tests --- dbos/admin_server_test.go | 26 +++++------ dbos/dbos_test.go | 2 +- dbos/initialize_test.go | 82 ++++++++++++++++++++++++++++++++++ dbos/logger_test.go | 2 +- dbos/queues_test.go | 37 +++++++++++++--- dbos/serialization_test.go | 3 ++ dbos/utils_test.go | 20 ++++----- dbos/workflows_test.go | 90 +++++++++++++++++++++----------------- 8 files changed, 192 insertions(+), 70 deletions(-) create mode 100644 dbos/initialize_test.go diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 1d9fb117..d7eb2a2a 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -11,26 +11,26 @@ import ( ) func TestAdminServer(t *testing.T) { - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() t.Run("Admin server is not started by default", func(t *testing.T) { - executor, err := NewDBOSContext(Config{ + ctx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) if err != nil { t.Skipf("Failed to initialize DBOS: %v", err) } - err = executor.Launch() + err = ctx.Launch() if err != nil { t.Skipf("Failed to initialize DBOS: %v", err) } // Ensure cleanup defer func() { - if executor != nil { - executor.Shutdown() + if ctx != nil { + ctx.Shutdown() } }() @@ -45,11 +45,11 @@ func TestAdminServer(t *testing.T) { } // Verify the DBOS executor doesn't have an admin server instance - if executor == nil { + if ctx == nil { t.Fatal("Expected DBOS instance to be created") } - exec := executor.(*dbosContext) + exec := ctx.(*dbosContext) if exec.adminServer != nil { t.Error("Expected admin server to be nil when not configured") } @@ -60,7 +60,7 @@ func TestAdminServer(t *testing.T) { // (This will be handled by the individual executor cleanup) // Launch DBOS with admin server once for all endpoint tests - executor, err := NewDBOSContext(Config{ + ctx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", AdminServer: true, @@ -68,15 +68,15 @@ func TestAdminServer(t *testing.T) { if err != nil { t.Skipf("Failed to initialize DBOS with admin server: %v", err) } - err = executor.Launch() + err = ctx.Launch() if err != nil { t.Skipf("Failed to initialize DBOS with admin server: %v", err) } // Ensure cleanup defer func() { - if executor != nil { - executor.Shutdown() + if ctx != nil { + ctx.Shutdown() } }() @@ -84,11 +84,11 @@ func TestAdminServer(t *testing.T) { time.Sleep(100 * time.Millisecond) // Verify the DBOS executor has an admin server instance - if executor == nil { + if ctx == nil { t.Fatal("Expected DBOS instance to be created") } - exec := executor.(*dbosContext) + exec := ctx.(*dbosContext) if exec.adminServer == nil { t.Fatal("Expected admin server to be created in DBOS instance") } diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 370c1c51..10de8b14 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -5,7 +5,7 @@ import ( ) func TestConfigValidationErrorTypes(t *testing.T) { - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() t.Run("FailsWithoutAppName", func(t *testing.T) { config := Config{ diff --git a/dbos/initialize_test.go b/dbos/initialize_test.go new file mode 100644 index 00000000..d15e8cf3 --- /dev/null +++ b/dbos/initialize_test.go @@ -0,0 +1,82 @@ +package dbos + +import ( + "testing" +) + +// TestInitializeReturnsExecutor verifies that our updated Initialize function works correctly +func TestInitializeReturnsExecutor(t *testing.T) { + databaseURL := getDatabaseURL() + + // Test that Initialize returns a DBOSExecutor + ctx, err := NewDBOSContext(Config{ + DatabaseURL: databaseURL, + AppName: "test-initialize", + }) + if err != nil { + t.Fatalf("Failed to initialize DBOS: %v", err) + } + defer func() { + if ctx != nil { + ctx.Shutdown() + } + }() // Clean up executor + + if ctx == nil { + t.Fatal("Initialize returned nil executor") + } + + // Test that executor implements DBOSContext interface + var _ DBOSContext = ctx + + // Test that we can call methods on the executor + appVersion := ctx.GetApplicationVersion() + if appVersion == "" { + t.Fatal("GetApplicationVersion returned empty string") + } + + scheduler := ctx.(*dbosContext).getWorkflowScheduler() + if scheduler == nil { + t.Fatal("getWorkflowScheduler returned nil") + } +} + +// TestWithWorkflowWithExecutor verifies that WithWorkflow works with an executor +func TestWithWorkflowWithExecutor(t *testing.T) { + ctx := setupDBOS(t) + + // Test workflow function + testWorkflow := func(ctx DBOSContext, input string) (string, error) { + return "hello " + input, nil + } + + // Test that RegisterWorkflow works with executor + RegisterWorkflow(ctx, testWorkflow) + + // Test executing the workflow + handle, err := RunAsWorkflow(ctx, testWorkflow, "world") + if err != nil { + t.Fatalf("Failed to execute workflow: %v", err) + } + + result, err := handle.GetResult() + if err != nil { + t.Fatalf("Failed to get workflow result: %v", err) + } + + expected := "hello world" + if result != expected { + t.Fatalf("Expected %q, got %q", expected, result) + } +} + +// TestSetupDBOSReturnsExecutor verifies that setupDBOS returns an executor +func TestSetupDBOSReturnsExecutor(t *testing.T) { + executor := setupDBOS(t) + + if executor == nil { + t.Fatal("setupDBOS returned nil executor") + } + + // Test succeeded - executor is valid +} diff --git a/dbos/logger_test.go b/dbos/logger_test.go index 4b7073b8..72a9acf0 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -8,7 +8,7 @@ import ( ) func TestLogger(t *testing.T) { - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() t.Run("Default logger", func(t *testing.T) { dbosCtx, err := NewDBOSContext(Config{ diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 1ea9592d..813632b3 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -177,12 +177,14 @@ func TestWorkflowQueues(t *testing.T) { } }) + /* TODO: we will move queue registry in the new interface in a subsequent PR t.Run("DynamicRegistration", func(t *testing.T) { q := NewWorkflowQueue("dynamic-queue") if len(q.name) > 0 { t.Fatalf("expected nil queue for dynamic registration after DBOS initialization, got %v", q) } }) + */ t.Run("QueueWorkflowDLQ", func(t *testing.T) { workflowID := "blocking-workflow-test" @@ -303,6 +305,11 @@ func TestQueueRecovery(t *testing.T) { } RegisterWorkflow(dbosCtx, recoveryWorkflowFunc) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } + queuedSteps := 5 for i := range recoveryStepEvents { @@ -402,7 +409,7 @@ var ( ) func TestGlobalConcurrency(t *testing.T) { - dbosContext := setupDBOS(t) + dbosCtx := setupDBOS(t) // Create workflow with dbosContext globalConcurrencyWorkflowFunc := func(ctx DBOSContext, input string) (string, error) { @@ -415,15 +422,20 @@ func TestGlobalConcurrency(t *testing.T) { } return input, nil } - RegisterWorkflow(dbosContext, globalConcurrencyWorkflowFunc) + RegisterWorkflow(dbosCtx, globalConcurrencyWorkflowFunc) + + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } // Enqueue two workflows - handle1, err := RunAsWorkflow(dbosContext, globalConcurrencyWorkflowFunc, "workflow1", WithQueue(globalConcurrencyQueue.name)) + handle1, err := RunAsWorkflow(dbosCtx, globalConcurrencyWorkflowFunc, "workflow1", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow1: %v", err) } - handle2, err := RunAsWorkflow(dbosContext, globalConcurrencyWorkflowFunc, "workflow2", WithQueue(globalConcurrencyQueue.name)) + handle2, err := RunAsWorkflow(dbosCtx, globalConcurrencyWorkflowFunc, "workflow2", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow2: %v", err) } @@ -465,7 +477,7 @@ func TestGlobalConcurrency(t *testing.T) { if result2 != "workflow2" { t.Fatalf("expected result from workflow2 to be 'workflow2', got %v", result2) } - if !queueEntriesAreCleanedUp(dbosContext) { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } } @@ -498,6 +510,11 @@ func TestWorkerConcurrency(t *testing.T) { } RegisterWorkflow(dbosCtx, blockingWfFunc) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } + // First enqueue four blocking workflows handle1, err := RunAsWorkflow(dbosCtx, blockingWfFunc, 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) if err != nil { @@ -653,6 +670,11 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } RegisterWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf2) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } + // Enqueue two workflows on a queue with worker concurrency = 1 handle1, err := RunAsWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf1, "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) if err != nil { @@ -760,6 +782,11 @@ func TestQueueRateLimiter(t *testing.T) { // Create workflow with dbosContext RegisterWorkflow(dbosCtx, rateLimiterTestWorkflow) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } + limit := 5 period := 1.8 numWaves := 3 diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 7ab3a3be..a7129c14 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -20,12 +20,15 @@ import ( // Builtin types func encodingStepBuiltinTypes(_ context.Context, input ...any) (int, error) { if len(input) == 0 { + fmt.Println("No input provided to encodingStepBuiltinTypes") return 0, errors.New("step error") } val, ok := input[0].(int) + fmt.Println("Input to encodingStepBuiltinTypes:", val, "ok:", ok) if !ok { return 0, errors.New("step error") } + fmt.Println("Processing input in encodingStepBuiltinTypes:", val) return val, errors.New("step error") } diff --git a/dbos/utils_test.go b/dbos/utils_test.go index 48bad8e3..fc839d2a 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -12,7 +12,7 @@ import ( "github.com/jackc/pgx/v5" ) -func getDatabaseURL(t *testing.T) string { +func getDatabaseURL() string { databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL") if databaseURL == "" { password := os.Getenv("PGPASSWORD") @@ -28,7 +28,7 @@ func getDatabaseURL(t *testing.T) string { func setupDBOS(t *testing.T) DBOSContext { t.Helper() - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() // Clean up the test database parsedURL, err := pgx.ParseConfig(databaseURL) @@ -113,9 +113,9 @@ func (e *Event) Clear() { /* Helpers */ // stopQueueRunner stops the queue runner for testing purposes -func stopQueueRunner(executor DBOSContext) { - if executor != nil { - exec := executor.(*dbosContext) +func stopQueueRunner(ctx DBOSContext) { + if ctx != nil { + exec := ctx.(*dbosContext) if exec.queueRunnerCancelFunc != nil { exec.queueRunnerCancelFunc() // Wait for queue runner to finish @@ -125,9 +125,9 @@ func stopQueueRunner(executor DBOSContext) { } // restartQueueRunner restarts the queue runner for testing purposes -func restartQueueRunner(executor DBOSContext) { - if executor != nil { - exec := executor.(*dbosContext) +func restartQueueRunner(ctx DBOSContext) { + if ctx != nil { + exec := ctx.(*dbosContext) // Create new context and cancel function // FIXME: cancellation now has to go through the DBOSContext ctx, cancel := context.WithCancel(context.Background()) @@ -155,12 +155,12 @@ func equal(a, b []int) bool { return true } -func queueEntriesAreCleanedUp(executor DBOSContext) bool { +func queueEntriesAreCleanedUp(ctx DBOSContext) bool { maxTries := 10 success := false for range maxTries { // Begin transaction - exec := executor.(*dbosContext) + exec := ctx.(*dbosContext) tx, err := exec.systemDB.(*systemDatabase).pool.Begin(context.Background()) if err != nil { return false diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 44d2e288..064dbfcc 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -36,6 +36,7 @@ func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) { } func simpleStep(ctx context.Context, input ...any) (string, error) { + fmt.Println("simpleStep called with input:", input) return "from step", nil } @@ -724,7 +725,11 @@ var ( func deadLetterQueueWorkflow(ctx DBOSContext, input string) (int, error) { recoveryCount++ - fmt.Printf("Dead letter queue workflow started, recovery count: %d\n", recoveryCount) + wfid, err := ctx.GetWorkflowID() + if err != nil { + return 0, fmt.Errorf("failed to get workflow ID: %v", err) + } + fmt.Printf("Dead letter queue workflow %s started, recovery count: %d\n", wfid, recoveryCount) deadLetterQueueStartEvent.Set() deadLetterQueueEvent.Wait() return 0, nil @@ -736,9 +741,9 @@ func infiniteDeadLetterQueueWorkflow(ctx DBOSContext, input string) (int, error) return 0, nil } func TestWorkflowDeadLetterQueue(t *testing.T) { - executor := setupDBOS(t) - RegisterWorkflow(executor, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) - RegisterWorkflow(executor, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) + RegisterWorkflow(dbosCtx, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries t.Run("DeadLetterQueueBehavior", func(t *testing.T) { deadLetterQueueEvent = NewEvent() @@ -747,7 +752,8 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Start a workflow that blocks forever wfID := uuid.NewString() - handle, err := RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) + fmt.Println(wfID) + handle, err := RunAsWorkflow(dbosCtx, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start dead letter queue workflow: %v", err) } @@ -756,7 +762,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Attempt to recover the blocked workflow the maximum number of times for i := range maxRecoveryAttempts { - _, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) + _, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) } @@ -769,7 +775,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { } // Verify an additional attempt throws a DLQ error and puts the workflow in the DLQ status - _, err = recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) + _, err = recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err == nil { t.Fatal("expected dead letter queue error but got none") } @@ -792,7 +798,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { } // Verify that attempting to start a workflow with the same ID throws a DLQ error - _, err = RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) + _, err = RunAsWorkflow(dbosCtx, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err == nil { t.Fatal("expected dead letter queue error when restarting workflow with same ID but got none") } @@ -861,7 +867,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Verify that a workflow with MaxRetries=0 (infinite retries) is retried infinitely wfID := uuid.NewString() - handle, err := RunAsWorkflow(executor, infiniteDeadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) + handle, err := RunAsWorkflow(dbosCtx, infiniteDeadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start infinite dead letter queue workflow: %v", err) } @@ -871,7 +877,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Attempt to recover the blocked workflow many times (should never fail) handles := []WorkflowHandle[any]{} for i := range _DEFAULT_MAX_RECOVERY_ATTEMPTS * 2 { - recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) } @@ -910,16 +916,25 @@ var ( ) func TestScheduledWorkflows(t *testing.T) { - executor := setupDBOS(t) - RegisterWorkflow(executor, func(ctx DBOSContext, scheduledTime time.Time) (string, error) { + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, func(ctx DBOSContext, scheduledTime time.Time) (string, error) { startTime := time.Now() counter++ if counter == 10 { return "", fmt.Errorf("counter reached 10, stopping workflow") } - return fmt.Sprintf("Scheduled workflow executed at %v", startTime), nil + select { + case counter1Ch <- startTime: + default: + } + return fmt.Sprintf("Scheduled workflow scheduled at time %v and executed at time %v", scheduledTime, startTime), nil }, WithSchedule("* * * * * *")) // Every second + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch executor: %v", err) + } + // Helper function to collect execution times collectExecutionTimes := func(ch chan time.Time, target int, timeout time.Duration) ([]time.Time, error) { var executionTimes []time.Time @@ -969,7 +984,7 @@ func TestScheduledWorkflows(t *testing.T) { // Stop the workflowScheduler and check if it stops executing currentCounter := counter - executor.(*dbosContext).getWorkflowScheduler().Stop() + dbosCtx.(*dbosContext).getWorkflowScheduler().Stop() time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions if counter >= currentCounter+2 { t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2) @@ -1079,12 +1094,19 @@ func stepThatCallsSend(ctx context.Context, input ...any) (string, error) { if len(input) == 0 { return "", fmt.Errorf("expected sendWorkflowInput") } - _, ok := input[0].(sendWorkflowInput) + i, ok := input[0].(sendWorkflowInput) if !ok { return "", fmt.Errorf("expected sendWorkflowInput, got %T", input[0]) } - // Note: Send cannot be called from within steps, this should fail - return "", fmt.Errorf("Send cannot be called from within a step") + err := Send(ctx.(DBOSContext), WorkflowSendInput[string]{ + DestinationID: i.DestinationID, + Topic: i.Topic, + Message: "message-from-step", + }) + if err != nil { + return "", err + } + return "send-completed", nil } func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) { @@ -1384,7 +1406,7 @@ func TestSendRecv(t *testing.T) { } }) - t.Run("ConcurrentRecv", func(t *testing.T) { + t.Run("TestSendRecv", func(t *testing.T) { // Test concurrent receivers - only 1 should timeout, others should get errors receiveTopic := "concurrent-recv-topic" @@ -1835,20 +1857,8 @@ var ( sleepStopEvent *Event ) -func sleepStep(ctx context.Context, input ...any) (time.Duration, error) { - if len(input) == 0 { - return 0, fmt.Errorf("expected duration") - } - duration, ok := input[0].(time.Duration) - if !ok { - return 0, fmt.Errorf("expected time.Duration, got %T", input[0]) - } - time.Sleep(duration) - return duration, nil -} - -func sleepRecoveryWorkflow(ctx DBOSContext, duration time.Duration) (time.Duration, error) { - result, err := RunAsStep(ctx, sleepStep, duration) +func sleepRecoveryWorkflow(dbosCtx DBOSContext, duration time.Duration) (time.Duration, error) { + result, err := dbosCtx.Sleep(duration) if err != nil { return 0, err } @@ -1859,8 +1869,8 @@ func sleepRecoveryWorkflow(ctx DBOSContext, duration time.Duration) (time.Durati } func TestSleep(t *testing.T) { - executor := setupDBOS(t) - RegisterWorkflow(executor, sleepRecoveryWorkflow) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, sleepRecoveryWorkflow) t.Run("SleepDurableRecovery", func(t *testing.T) { sleepStartEvent = NewEvent() @@ -1869,7 +1879,7 @@ func TestSleep(t *testing.T) { // Start a workflow that sleeps for 2 seconds then blocks sleepDuration := 2 * time.Second - handle, err := RunAsWorkflow(executor, sleepRecoveryWorkflow, sleepDuration) + handle, err := RunAsWorkflow(dbosCtx, sleepRecoveryWorkflow, sleepDuration) if err != nil { t.Fatalf("failed to start sleep recovery workflow: %v", err) } @@ -1879,7 +1889,7 @@ func TestSleep(t *testing.T) { // Run the workflow again and check the return time was less than the durable sleep startTime := time.Now() - _, err = RunAsWorkflow(executor, sleepRecoveryWorkflow, sleepDuration, WithWorkflowID(handle.GetWorkflowID())) + _, err = RunAsWorkflow(dbosCtx, sleepRecoveryWorkflow, sleepDuration, WithWorkflowID(handle.GetWorkflowID())) if err != nil { t.Fatalf("failed to start second sleep recovery workflow: %v", err) } @@ -1892,7 +1902,7 @@ func TestSleep(t *testing.T) { } // Verify the sleep step was recorded correctly - steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -1902,8 +1912,8 @@ func TestSleep(t *testing.T) { } step := steps[0] - if step.FunctionName != "dbos.sleepStep" { - t.Fatalf("expected step name to be 'dbos.sleepStep', got '%s'", step.FunctionName) + if step.FunctionName != "DBOS.sleep" { + t.Fatalf("expected step name to be 'DBOS.sleep', got '%s'", step.FunctionName) } if step.Error != nil { @@ -1915,7 +1925,7 @@ func TestSleep(t *testing.T) { t.Run("SleepCannotBeCalledOutsideWorkflow", func(t *testing.T) { // Attempt to call Sleep outside of a workflow context - _, err := RunAsStep(executor, sleepStep, 1*time.Second) + _, err := dbosCtx.Sleep(1 * time.Second) if err == nil { t.Fatal("expected error when calling Sleep outside of workflow context, but got none") } From c3ea87b02de055e5c31079036a78c7eb515b6738 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 13:30:54 -0700 Subject: [PATCH 16/30] simpler --- dbos/workflow.go | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index 9b25e4a6..edc24e1f 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -355,7 +355,7 @@ type DBOSContextKey string const workflowStateKey DBOSContextKey = "workflowState" -type GenericWorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P) (R, error) +type GenericWorkflowFunc[P any, R any] func(ctx DBOSContext, input P) (R, error) type WorkflowFunc func(ctx DBOSContext, input any) (any, error) type workflowParams struct { @@ -476,17 +476,17 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, return nil, fmt.Errorf("unexpected workflow handle type: %T", handle) } -func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { +func (c *dbosContext) RunAsWorkflow(w_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { // Apply options to build params params := workflowParams{ - applicationVersion: dbosCtx.GetApplicationVersion(), + applicationVersion: c.GetApplicationVersion(), } for _, opt := range opts { opt(¶ms) } // Lookup the registry for registration-time options - registeredWorkflow, exists := dbosCtx.(*dbosContext).workflowRegistry[params.workflowName] + registeredWorkflow, exists := c.workflowRegistry[params.workflowName] if !exists { return nil, newNonExistentWorkflowError(params.workflowName) } @@ -495,7 +495,7 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input } // Check if we are within a workflow (and thus a child workflow) - parentWorkflowState, ok := dbosCtx.Value(workflowStateKey).(*workflowState) + parentWorkflowState, ok := c.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil // TODO Check if cancelled @@ -515,12 +515,12 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := dbosCtx.(*dbosContext).systemDB.CheckChildWorkflow(dbosCtx.(*dbosContext).ctx, parentWorkflowState.workflowID, parentWorkflowState.stepID) + childWorkflowID, err := c.systemDB.CheckChildWorkflow(c.ctx, parentWorkflowState.workflowID, parentWorkflowState.stepID) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { - return &workflowPollingHandle[any]{workflowID: *childWorkflowID, dbosContext: dbosCtx}, nil + return &workflowPollingHandle[any]{workflowID: *childWorkflowID, dbosContext: c}, nil } } @@ -534,23 +534,23 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input workflowStatus := WorkflowStatus{ Name: params.workflowName, ApplicationVersion: params.applicationVersion, - ExecutorID: dbosCtx.GetExecutorID(), + ExecutorID: c.GetExecutorID(), Status: status, ID: workflowID, CreatedAt: time.Now(), Deadline: params.deadline, // TODO compute the deadline based on the timeout Timeout: params.timeout, Input: input, - ApplicationID: dbosCtx.GetApplicationID(), + ApplicationID: c.GetApplicationID(), QueueName: params.queueName, } // Init status and record child workflow relationship in a single transaction - tx, err := dbosCtx.(*dbosContext).systemDB.(*systemDatabase).pool.Begin(dbosCtx.(*dbosContext).ctx) + tx, err := c.systemDB.(*systemDatabase).pool.Begin(c.ctx) if err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) } - defer tx.Rollback(dbosCtx.(*dbosContext).ctx) // Rollback if not committed + defer tx.Rollback(c.ctx) // Rollback if not committed // Insert workflow status with transaction insertInput := insertWorkflowStatusDBInput{ @@ -558,7 +558,7 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input maxRetries: params.maxRetries, tx: tx, } - insertStatusResult, err := dbosCtx.(*dbosContext).systemDB.InsertWorkflowStatus(dbosCtx.(*dbosContext).ctx, insertInput) + insertStatusResult, err := c.systemDB.InsertWorkflowStatus(c.ctx, insertInput) if err != nil { return nil, err } @@ -566,10 +566,10 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { // Commit the transaction to update the number of attempts and/or enact the enqueue - if err := tx.Commit(dbosCtx.(*dbosContext).ctx); err != nil { + if err := tx.Commit(c.ctx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowPollingHandle[any]{workflowID: workflowStatus.ID, dbosContext: dbosCtx}, nil + return &workflowPollingHandle[any]{workflowID: workflowStatus.ID, dbosContext: c}, nil } // Record child workflow relationship if this is a child workflow @@ -583,7 +583,7 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input stepID: stepID, tx: tx, } - err = dbosCtx.(*dbosContext).systemDB.RecordChildWorkflow(dbosCtx.(*dbosContext).ctx, childInput) + err = c.systemDB.RecordChildWorkflow(c.ctx, childInput) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow: %v", err)) } @@ -601,16 +601,16 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input } // Run the function in a goroutine - workflowCtx := WithValue(dbosCtx, workflowStateKey, wfState) - dbosCtx.(*dbosContext).workflowsWg.Add(1) + workflowCtx := WithValue(c, workflowStateKey, wfState) + c.workflowsWg.Add(1) go func() { - defer dbosCtx.(*dbosContext).workflowsWg.Done() + defer c.workflowsWg.Done() result, err := fn(workflowCtx, input) status := WorkflowStatusSuccess if err != nil { status = WorkflowStatusError } - recordErr := dbosCtx.(*dbosContext).systemDB.UpdateWorkflowOutcome(dbosCtx.(*dbosContext).ctx, updateWorkflowOutcomeDBInput{ + recordErr := c.systemDB.UpdateWorkflowOutcome(c.ctx, updateWorkflowOutcomeDBInput{ workflowID: workflowID, status: status, err: err, @@ -626,11 +626,11 @@ func (c *dbosContext) RunAsWorkflow(dbosCtx DBOSContext, fn WorkflowFunc, input }() // Commit the transaction - if err := tx.Commit(dbosCtx.(*dbosContext).ctx); err != nil { + if err := tx.Commit(c.ctx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowHandle[any]{workflowID: workflowID, outcomeChan: outcomeChan, dbosContext: dbosCtx}, nil + return &workflowHandle[any]{workflowID: workflowID, outcomeChan: outcomeChan, dbosContext: c}, nil } /******************************/ From 7a5413229a4a5dc734a2478a7a45dec6f8d64718 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:17:54 -0700 Subject: [PATCH 17/30] fix bug in WithValue --- dbos/dbos.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dbos/dbos.go b/dbos/dbos.go index f01b63d5..454156b1 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -146,6 +146,8 @@ func WithValue(ctx DBOSContext, key, val any) DBOSContext { ctx: context.WithValue(dbosCtx.ctx, key, val), systemDB: dbosCtx.systemDB, workflowsWg: dbosCtx.workflowsWg, + workflowRegistry: dbosCtx.workflowRegistry, + workflowRegMutex: dbosCtx.workflowRegMutex, applicationVersion: dbosCtx.applicationVersion, executorID: dbosCtx.executorID, applicationID: dbosCtx.applicationID, From e4a98066ad2a0542e9700c6739ff4ad1707cbdb2 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:18:12 -0700 Subject: [PATCH 18/30] comment --- dbos/queue.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dbos/queue.go b/dbos/queue.go index 6960a271..62042151 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -98,7 +98,7 @@ func NewWorkflowQueue(name string, options ...queueOption) WorkflowQueue { return q } -func queueRunner(executor *dbosContext) { +func queueRunner(ctx *dbosContext) { const ( baseInterval = 1.0 // Base interval in seconds minInterval = 1.0 // Minimum polling interval in seconds @@ -118,7 +118,7 @@ func queueRunner(executor *dbosContext) { for queueName, queue := range workflowQueueRegistry { getLogger().Debug("Processing queue", "queue_name", queueName) // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := executor.systemDB.DequeueWorkflows(executor.queueRunnerCtx, queue, executor.executorID, executor.applicationVersion) + dequeuedWorkflows, err := ctx.systemDB.DequeueWorkflows(ctx.queueRunnerCtx, queue, ctx.executorID, ctx.applicationVersion) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { @@ -139,7 +139,7 @@ func queueRunner(executor *dbosContext) { } for _, workflow := range dequeuedWorkflows { // Find the workflow in the registry - registeredWorkflow, exists := executor.workflowRegistry[workflow.name] + registeredWorkflow, exists := ctx.workflowRegistry[workflow.name] if !exists { getLogger().Error("workflow function not found in registry", "workflow_name", workflow.name) continue @@ -161,7 +161,8 @@ func queueRunner(executor *dbosContext) { } } - _, err := registeredWorkflow.wrappedFunction(executor, input, WithWorkflowID(workflow.id)) + // XXX this demonstrate why contexts cannot be used globally -- the task does not inherit the context used in the program that enqueued it + _, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id)) if err != nil { getLogger().Error("Error running queued workflow", "error", err) } @@ -183,7 +184,7 @@ func queueRunner(executor *dbosContext) { // Sleep with jittered interval, but allow early exit on context cancellation select { - case <-executor.queueRunnerCtx.Done(): + case <-ctx.queueRunnerCtx.Done(): getLogger().Info("Queue runner stopping due to context cancellation") return case <-time.After(sleepDuration): From ab75aaf11205d04d044380ef85363054dfef2c28 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:18:48 -0700 Subject: [PATCH 19/30] identify uncancellable AwaitWorkflowResult -- will have to pass it a cancellable context --- dbos/system_database.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dbos/system_database.go b/dbos/system_database.go index 05f10c41..9ceb7551 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -645,6 +645,12 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str query := `SELECT status, output, error FROM dbos.workflow_status WHERE workflow_uuid = $1` var status WorkflowStatusType for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + row := s.pool.QueryRow(ctx, query, workflowID) var outputString *string var errorStr *string @@ -671,7 +677,7 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str case WorkflowStatusCancelled: return nil, newAwaitedWorkflowCancelledError(workflowID) default: - time.Sleep(1 * time.Second) // Wait before checking again + time.Sleep(1 * time.Second) } } } From 60808e3e30ba2104e035d01cc5582f0340f25af5 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:19:53 -0700 Subject: [PATCH 20/30] add defensive type check --- dbos/workflow.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index edc24e1f..db4a246a 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -320,7 +320,12 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ // Register a type-erased version of the durable workflow for recovery typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { - return fn(ctx, input.(P)) + // This type check is redundant with the one in the wrapper, but I'd better be safe than sorry + typedInput, ok := input.(P) + if !ok { + return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) + } + return fn(ctx, typedInput) }) typeErasedWrapper := WrappedWorkflowFunc(func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { From 81c2fde52be65f90cb037ec1eff9e92969284ada Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:20:12 -0700 Subject: [PATCH 21/30] wrapper should use provided context --- dbos/workflow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index db4a246a..5c22ad75 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -335,7 +335,7 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ } opts = append(opts, withWorkflowName(fqn)) // Append the name so dbosCtx.RunAsWorkflow can look it up from the registry to apply registration-time options - handle, err := dbosCtx.RunAsWorkflow(ctx, typedErasedWorkflow, typedInput, opts...) + handle, err := ctx.RunAsWorkflow(ctx, typedErasedWorkflow, typedInput, opts...) if err != nil { return nil, err } From 5149aa51b9ae16f3d959393560af510575761ffd Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:20:20 -0700 Subject: [PATCH 22/30] typo --- dbos/workflow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index 5c22ad75..ddca211e 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -481,7 +481,7 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, return nil, fmt.Errorf("unexpected workflow handle type: %T", handle) } -func (c *dbosContext) RunAsWorkflow(w_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { +func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { // Apply options to build params params := workflowParams{ applicationVersion: c.GetApplicationVersion(), From cebb49b92b00af38c69a36874a092fab71dabe59 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:46:36 -0700 Subject: [PATCH 23/30] fix small race in test + check recovery handles == original handles --- dbos/workflows_test.go | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 064dbfcc..4e4947da 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -1493,7 +1493,8 @@ func TestSendRecv(t *testing.T) { } var ( - setEventIdempotencyEvent = NewEvent() + setEventStartIdempotencyEvent = NewEvent() + setEvenStopIdempotencyEvent = NewEvent() getEventStartIdempotencyEvent = NewEvent() getEventStopIdempotencyEvent = NewEvent() setSecondEventSignal = NewEvent() @@ -1548,7 +1549,8 @@ func setEventIdempotencyWorkflow(ctx DBOSContext, input setEventWorkflowInput) ( if err != nil { return "", err } - setEventIdempotencyEvent.Wait() + setEventStartIdempotencyEvent.Set() + setEvenStopIdempotencyEvent.Wait() return "idempotent-set-completed", nil } @@ -1751,9 +1753,11 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("failed to start get event idempotency workflow: %v", err) } - // Wait for the get event workflow to signal it has received the event + // Wait for the workflows to signal it has received the event getEventStartIdempotencyEvent.Wait() getEventStartIdempotencyEvent.Clear() + setEventStartIdempotencyEvent.Wait() + setEventStartIdempotencyEvent.Clear() // Attempt recovering both workflows. Each should have exactly 1 step. recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) @@ -1765,6 +1769,7 @@ func TestSetGetEvent(t *testing.T) { } getEventStartIdempotencyEvent.Wait() + setEventStartIdempotencyEvent.Wait() // Verify step counts setSteps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) @@ -1784,7 +1789,7 @@ func TestSetGetEvent(t *testing.T) { } // Complete the workflows - setEventIdempotencyEvent.Set() + setEvenStopIdempotencyEvent.Set() getEventStopIdempotencyEvent.Set() setResult, err := setHandle.GetResult() @@ -1802,6 +1807,29 @@ func TestSetGetEvent(t *testing.T) { if getResult != "idempotency-message" { t.Fatalf("expected result to be 'idempotency-message', got '%s'", getResult) } + + // Check the recovered handle returns the same result + for _, recoveredHandle := range recoveredHandles { + if recoveredHandle.GetWorkflowID() == setHandle.GetWorkflowID() { + recoveredSetResult, err := recoveredHandle.GetResult() + if err != nil { + t.Fatalf("failed to get result from recovered set event idempotency workflow: %v", err) + } + if recoveredSetResult != "idempotent-set-completed" { + t.Fatalf("expected recovered result to be 'idempotent-set-completed', got '%s'", recoveredSetResult) + + } + } + if recoveredHandle.GetWorkflowID() == getHandle.GetWorkflowID() { + recoveredGetResult, err := recoveredHandle.GetResult() + if err != nil { + t.Fatalf("failed to get result from recovered get event idempotency workflow: %v", err) + } + if recoveredGetResult != "idempotency-message" { + t.Fatalf("expected recovered result to be 'idempotency-message', got '%s'", recoveredGetResult) + } + } + } }) t.Run("ConcurrentGetEvent", func(t *testing.T) { From 2a77ad8cbb3be88f85eefc7b56f49d6a516104ec Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 15:55:37 -0700 Subject: [PATCH 24/30] nit --- dbos/workflows_test.go | 117 ++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 59 deletions(-) diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 4e4947da..2b2d01fb 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -95,8 +95,7 @@ func Identity[T any](dbosCtx DBOSContext, in T) (T, error) { } func TestWorkflowsRegistration(t *testing.T) { - executor := setupDBOS(t) - dbosCtx := executor + dbosCtx := setupDBOS(t) // Setup workflows with executor RegisterWorkflow(dbosCtx, simpleWorkflow) @@ -279,7 +278,7 @@ func TestWorkflowsRegistration(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := tc.workflowFunc(executor, tc.input, WithWorkflowID(uuid.NewString())) + result, err := tc.workflowFunc(dbosCtx, tc.input, WithWorkflowID(uuid.NewString())) if tc.expectError { if err == nil { @@ -639,8 +638,8 @@ func TestWorkflowIdempotency(t *testing.T) { } func TestWorkflowRecovery(t *testing.T) { - executor := setupDBOS(t) - RegisterWorkflow(executor, idempotencyWorkflowWithStep) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, idempotencyWorkflowWithStep) t.Run("RecoveryResumeWhereItLeftOff", func(t *testing.T) { // Reset the global counter idempotencyCounter = 0 @@ -649,7 +648,7 @@ func TestWorkflowRecovery(t *testing.T) { input := "recovery-test" idempotencyWorkflowWithStepEvent = NewEvent() blockingStepStopEvent = NewEvent() - handle1, err := RunAsWorkflow(executor, idempotencyWorkflowWithStep, input) + handle1, err := RunAsWorkflow(dbosCtx, idempotencyWorkflowWithStep, input) if err != nil { t.Fatalf("failed to execute workflow first time: %v", err) } @@ -657,7 +656,7 @@ func TestWorkflowRecovery(t *testing.T) { idempotencyWorkflowWithStepEvent.Wait() // Wait for the first step to complete. The second spins forever. // Run recovery for pending workflows with "local" executor - recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -686,7 +685,7 @@ func TestWorkflowRecovery(t *testing.T) { } // Using ListWorkflows, retrieve the status of the workflow - workflows, err := executor.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{handle1.GetWorkflowID()}, }) if err != nil { @@ -932,7 +931,7 @@ func TestScheduledWorkflows(t *testing.T) { err := dbosCtx.Launch() if err != nil { - t.Fatalf("failed to launch executor: %v", err) + t.Fatalf("failed to launch DBOS: %v", err) } // Helper function to collect execution times @@ -1118,27 +1117,27 @@ type sendRecvType struct { } func TestSendRecv(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) // Register all send/recv workflows with executor - RegisterWorkflow(executor, sendWorkflow) - RegisterWorkflow(executor, receiveWorkflow) - RegisterWorkflow(executor, receiveWorkflowCoordinated) - RegisterWorkflow(executor, sendStructWorkflow) - RegisterWorkflow(executor, receiveStructWorkflow) - RegisterWorkflow(executor, sendIdempotencyWorkflow) - RegisterWorkflow(executor, receiveIdempotencyWorkflow) - RegisterWorkflow(executor, workflowThatCallsSendInStep) + RegisterWorkflow(dbosCtx, sendWorkflow) + RegisterWorkflow(dbosCtx, receiveWorkflow) + RegisterWorkflow(dbosCtx, receiveWorkflowCoordinated) + RegisterWorkflow(dbosCtx, sendStructWorkflow) + RegisterWorkflow(dbosCtx, receiveStructWorkflow) + RegisterWorkflow(dbosCtx, sendIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, receiveIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, workflowThatCallsSendInStep) t.Run("SendRecvSuccess", func(t *testing.T) { // Start the receive workflow - receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "test-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "test-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send a message to the receive workflow - handle, err := RunAsWorkflow(executor, sendWorkflow, sendWorkflowInput{ + handle, err := RunAsWorkflow(dbosCtx, sendWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "test-topic", }) @@ -1158,21 +1157,21 @@ func TestSendRecv(t *testing.T) { if result != "message1-message2-message3" { t.Fatalf("expected received message to be 'message1-message2-message3', got '%s'", result) } - // XXX let's see how this works when all the tests run - if time.Since(start) > 5*time.Second { + // XXX This is not a great condition: when all the tests run there's quite some randomness to this + if time.Since(start) > 10*time.Second { t.Fatalf("receive workflow took too long to complete, expected < 5s, got %v", time.Since(start)) } }) t.Run("SendRecvCustomStruct", func(t *testing.T) { // Start the receive workflow - receiveHandle, err := RunAsWorkflow(executor, receiveStructWorkflow, "struct-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveStructWorkflow, "struct-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send the struct to the receive workflow - sendHandle, err := RunAsWorkflow(executor, sendStructWorkflow, sendWorkflowInput{ + sendHandle, err := RunAsWorkflow(dbosCtx, sendStructWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "struct-topic", }) @@ -1202,7 +1201,7 @@ func TestSendRecv(t *testing.T) { destUUID := uuid.NewString() // Send to non-existent UUID should fail - handle, err := RunAsWorkflow(executor, sendWorkflow, sendWorkflowInput{ + handle, err := RunAsWorkflow(dbosCtx, sendWorkflow, sendWorkflowInput{ DestinationID: destUUID, Topic: "testtopic", }) @@ -1232,7 +1231,7 @@ func TestSendRecv(t *testing.T) { t.Run("RecvTimeout", func(t *testing.T) { // Create a receive workflow that tries to receive a message but no send happens - receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "timeout-test-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "timeout-test-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } @@ -1247,7 +1246,7 @@ func TestSendRecv(t *testing.T) { t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) { // Attempt to run Recv outside of a workflow context - _, err := Recv[string](executor, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) + _, err := Recv[string](dbosCtx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) if err == nil { t.Fatal("expected error when running Recv outside of workflow context, but got none") } @@ -1271,14 +1270,14 @@ func TestSendRecv(t *testing.T) { t.Run("SendOutsideWorkflow", func(t *testing.T) { // Start a receive workflow to have a valid destination - receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "outside-workflow-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "outside-workflow-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send messages from outside a workflow context (should work now) for i := range 3 { - err = Send(executor, WorkflowSendInput[string]{ + err = Send(dbosCtx, WorkflowSendInput[string]{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "outside-workflow-topic", Message: fmt.Sprintf("message%d", i+1), @@ -1299,13 +1298,13 @@ func TestSendRecv(t *testing.T) { }) t.Run("SendRecvIdempotency", func(t *testing.T) { // Start the receive workflow and wait for it to be ready - receiveHandle, err := RunAsWorkflow(executor, receiveIdempotencyWorkflow, "idempotency-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveIdempotencyWorkflow, "idempotency-topic") if err != nil { t.Fatalf("failed to start receive idempotency workflow: %v", err) } // Send the message to the receive workflow - sendHandle, err := RunAsWorkflow(executor, sendIdempotencyWorkflow, sendWorkflowInput{ + sendHandle, err := RunAsWorkflow(dbosCtx, sendIdempotencyWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "idempotency-topic", }) @@ -1317,21 +1316,21 @@ func TestSendRecv(t *testing.T) { receiveIdempotencyStartEvent.Wait() // Attempt recovering both workflows. There should be only 2 steps recorded after recovery. - recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } if len(recoveredHandles) != 2 { t.Fatalf("expected 2 recovered handles, got %d", len(recoveredHandles)) } - steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } if len(steps) != 1 { t.Fatalf("expected 1 step in send idempotency workflow, got %d", len(steps)) } - steps, err = executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) + steps, err = dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for receive idempotency workflow: %v", err) } @@ -1360,13 +1359,13 @@ func TestSendRecv(t *testing.T) { t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) { // Start a receive workflow to have a valid destination - receiveHandle, err := RunAsWorkflow(executor, receiveWorkflow, "send-within-step-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "send-within-step-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Execute the workflow that tries to call Send within a step - handle, err := RunAsWorkflow(executor, workflowThatCallsSendInStep, sendWorkflowInput{ + handle, err := RunAsWorkflow(dbosCtx, workflowThatCallsSendInStep, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "send-within-step-topic", }) @@ -1420,7 +1419,7 @@ func TestSendRecv(t *testing.T) { // Start all receivers - they will signal when ready and wait for coordination for i := range numReceivers { concurrentRecvReadyEvents[i] = NewEvent() - receiveHandle, err := RunAsWorkflow(executor, receiveWorkflowCoordinated, struct { + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflowCoordinated, struct { Topic string i int }{ @@ -1569,21 +1568,21 @@ func getEventIdempotencyWorkflow(ctx DBOSContext, input setEventWorkflowInput) ( } func TestSetGetEvent(t *testing.T) { - executor := setupDBOS(t) + dbosCtx := setupDBOS(t) // Register all set/get event workflows with executor - RegisterWorkflow(executor, setEventWorkflow) - RegisterWorkflow(executor, getEventWorkflow) - RegisterWorkflow(executor, setTwoEventsWorkflow) - RegisterWorkflow(executor, setEventIdempotencyWorkflow) - RegisterWorkflow(executor, getEventIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, setEventWorkflow) + RegisterWorkflow(dbosCtx, getEventWorkflow) + RegisterWorkflow(dbosCtx, setTwoEventsWorkflow) + RegisterWorkflow(dbosCtx, setEventIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, getEventIdempotencyWorkflow) t.Run("SetGetEventFromWorkflow", func(t *testing.T) { // Clear the signal event before starting setSecondEventSignal.Clear() // Start the workflow that sets two events - setHandle, err := RunAsWorkflow(executor, setTwoEventsWorkflow, setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setTwoEventsWorkflow, setEventWorkflowInput{ Key: "test-workflow", Message: "unused", }) @@ -1592,7 +1591,7 @@ func TestSetGetEvent(t *testing.T) { } // Start a workflow to get the first event - getFirstEventHandle, err := RunAsWorkflow(executor, getEventWorkflow, setEventWorkflowInput{ + getFirstEventHandle, err := RunAsWorkflow(dbosCtx, getEventWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), // Target workflow ID Message: "event1", // Event key }) @@ -1613,7 +1612,7 @@ func TestSetGetEvent(t *testing.T) { setSecondEventSignal.Set() // Start a workflow to get the second event - getSecondEventHandle, err := RunAsWorkflow(executor, getEventWorkflow, setEventWorkflowInput{ + getSecondEventHandle, err := RunAsWorkflow(dbosCtx, getEventWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), // Target workflow ID Message: "event2", // Event key }) @@ -1642,7 +1641,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventFromOutsideWorkflow", func(t *testing.T) { // Start a workflow that sets an event - setHandle, err := RunAsWorkflow(executor, setEventWorkflow, setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventWorkflow, setEventWorkflowInput{ Key: "test-key", Message: "test-message", }) @@ -1657,7 +1656,7 @@ func TestSetGetEvent(t *testing.T) { } // Start a workflow that gets the event from outside the original workflow - message, err := GetEvent[string](executor, WorkflowGetEventInput{ + message, err := GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "test-key", Timeout: 3 * time.Second, @@ -1673,7 +1672,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventTimeout", func(t *testing.T) { // Try to get an event from a non-existent workflow nonExistentID := uuid.NewString() - message, err := GetEvent[string](executor, WorkflowGetEventInput{ + message, err := GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: nonExistentID, Key: "test-key", Timeout: 3 * time.Second, @@ -1686,7 +1685,7 @@ func TestSetGetEvent(t *testing.T) { } // Try to get an event from an existing workflow but with a key that doesn't exist - setHandle, err := RunAsWorkflow(executor, setEventWorkflow, setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventWorkflow, setEventWorkflowInput{ Key: "test-key", Message: "test-message", }) @@ -1697,7 +1696,7 @@ func TestSetGetEvent(t *testing.T) { if err != nil { t.Fatal("failed to get result from set event workflow:", err) } - message, err = GetEvent[string](executor, WorkflowGetEventInput{ + message, err = GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "non-existent-key", Timeout: 3 * time.Second, @@ -1712,7 +1711,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("SetGetEventMustRunInsideWorkflows", func(t *testing.T) { // Attempt to run SetEvent outside of a workflow context - err := SetEvent(executor, WorkflowSetEventInputGeneric[string]{Key: "test-key", Message: "test-message"}) + err := SetEvent(dbosCtx, WorkflowSetEventInputGeneric[string]{Key: "test-key", Message: "test-message"}) if err == nil { t.Fatal("expected error when running SetEvent outside of workflow context, but got none") } @@ -1736,7 +1735,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("SetGetEventIdempotency", func(t *testing.T) { // Start the set event workflow - setHandle, err := RunAsWorkflow(executor, setEventIdempotencyWorkflow, setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventIdempotencyWorkflow, setEventWorkflowInput{ Key: "idempotency-key", Message: "idempotency-message", }) @@ -1745,7 +1744,7 @@ func TestSetGetEvent(t *testing.T) { } // Start the get event workflow - getHandle, err := RunAsWorkflow(executor, getEventIdempotencyWorkflow, setEventWorkflowInput{ + getHandle, err := RunAsWorkflow(dbosCtx, getEventIdempotencyWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), Message: "idempotency-key", }) @@ -1760,7 +1759,7 @@ func TestSetGetEvent(t *testing.T) { setEventStartIdempotencyEvent.Clear() // Attempt recovering both workflows. Each should have exactly 1 step. - recoveredHandles, err := recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -1772,7 +1771,7 @@ func TestSetGetEvent(t *testing.T) { setEventStartIdempotencyEvent.Wait() // Verify step counts - setSteps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) + setSteps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for set event idempotency workflow: %v", err) } @@ -1780,7 +1779,7 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("expected 1 step in set event idempotency workflow, got %d", len(setSteps)) } - getSteps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), getHandle.GetWorkflowID()) + getSteps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), getHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for get event idempotency workflow: %v", err) } @@ -1834,7 +1833,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("ConcurrentGetEvent", func(t *testing.T) { // Set event - setHandle, err := RunAsWorkflow(executor, setEventWorkflow, setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventWorkflow, setEventWorkflowInput{ Key: "concurrent-event-key", Message: "concurrent-event-message", }) @@ -1855,7 +1854,7 @@ func TestSetGetEvent(t *testing.T) { for range numGoroutines { go func() { defer wg.Done() - res, err := GetEvent[string](executor, WorkflowGetEventInput{ + res, err := GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "concurrent-event-key", Timeout: 10 * time.Second, From aad3c378718ea93d8c0fb2e277916581f1fb67fe Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 16:28:28 -0700 Subject: [PATCH 25/30] Revert changes to step interface for now --- dbos/dbos.go | 2 +- dbos/queues_test.go | 11 ++------ dbos/serialization_test.go | 25 +++------------- dbos/workflow.go | 16 +++++------ dbos/workflows_test.go | 58 +++++++++++++++----------------------- 5 files changed, 37 insertions(+), 75 deletions(-) diff --git a/dbos/dbos.go b/dbos/dbos.go index 454156b1..f2090ef2 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -68,7 +68,7 @@ type DBOSContext interface { Shutdown() // Workflow operations - RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, error) + RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) Send(_ DBOSContext, input WorkflowSendInputInternal) error Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 813632b3..488dbca2 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -44,15 +44,8 @@ func queueWorkflow(ctx DBOSContext, input string) (string, error) { return step1, nil } -func queueStep(ctx context.Context, input ...any) (string, error) { - if len(input) == 0 { - return "", nil - } - str, ok := input[0].(string) - if !ok { - return "", fmt.Errorf("expected string input, got %T", input[0]) - } - return str, nil +func queueStep(ctx context.Context, input string) (string, error) { + return input, nil } func TestWorkflowQueues(t *testing.T) { diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index a7129c14..d10e0f87 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -18,18 +18,8 @@ import ( */ // Builtin types -func encodingStepBuiltinTypes(_ context.Context, input ...any) (int, error) { - if len(input) == 0 { - fmt.Println("No input provided to encodingStepBuiltinTypes") - return 0, errors.New("step error") - } - val, ok := input[0].(int) - fmt.Println("Input to encodingStepBuiltinTypes:", val, "ok:", ok) - if !ok { - return 0, errors.New("step error") - } - fmt.Println("Processing input in encodingStepBuiltinTypes:", val) - return val, errors.New("step error") +func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { + return input, errors.New("step error") } func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) { @@ -65,16 +55,9 @@ func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOut }) } -func encodingStepStruct(ctx context.Context, input ...any) (StepOutputStruct, error) { - if len(input) == 0 { - return StepOutputStruct{}, nil - } - stepInput, ok := input[0].(StepInputStruct) - if !ok { - return StepOutputStruct{}, nil - } +func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputStruct, error) { return StepOutputStruct{ - A: stepInput, + A: input, B: "processed by encodingStepStruct", }, nil } diff --git a/dbos/workflow.go b/dbos/workflow.go index ddca211e..e32aa671 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -642,8 +642,8 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o /******* STEP FUNCTIONS *******/ /******************************/ -type StepFunc func(ctx context.Context, input ...any) (any, error) -type GenericStepFunc[R any] func(ctx context.Context, input ...any) (R, error) +type StepFunc func(ctx context.Context, input any) (any, error) +type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error) const StepParamsKey DBOSContextKey = "stepParams" @@ -687,7 +687,7 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams { var typeErasedStepNameToStepName = make(map[string]string) -func RunAsStep[R any](dbosCtx DBOSContext, fn GenericStepFunc[R], input ...any) (R, error) { +func RunAsStep[P any, R any](dbosCtx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) { if dbosCtx == nil { return *new(R), newStepExecutionError("", "", "dbosCtx cannot be nil") } @@ -697,14 +697,14 @@ func RunAsStep[R any](dbosCtx DBOSContext, fn GenericStepFunc[R], input ...any) } // Type-erase the function based on its actual type - typeErasedFn := StepFunc(func(ctx context.Context, i ...any) (any, error) { - return fn(ctx, i...) + typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) { + return fn(ctx, input.(P)) }) typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // Call the executor method - result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input...) + result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input) if err != nil { // In case the errors comes from the DBOS step logic, the result will be nil and we must handle it if result == nil { @@ -722,7 +722,7 @@ func RunAsStep[R any](dbosCtx DBOSContext, fn GenericStepFunc[R], input ...any) return typedResult, nil } -func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, error) { +func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) { // Get workflow state from context wfState, ok := c.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { @@ -770,7 +770,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input ...any) (any, // Spawn a child DBOSContext with the step state stepCtx := WithValue(c, workflowStateKey, &stepState) - stepOutput, stepError := fn(stepCtx, input...) + stepOutput, stepError := fn(stepCtx, input) // Retry if MaxRetries > 0 and the first execution failed var joinedErrors error diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 2b2d01fb..239da430 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -35,12 +35,11 @@ func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) { return RunAsStep(dbosCtx, simpleStep, input) } -func simpleStep(ctx context.Context, input ...any) (string, error) { - fmt.Println("simpleStep called with input:", input) +func simpleStep(ctx context.Context, input string) (string, error) { return "from step", nil } -func simpleStepError(ctx context.Context, input ...any) (string, error) { +func simpleStepError(ctx context.Context, input string) (string, error) { return "", fmt.Errorf("step failure") } @@ -49,17 +48,8 @@ func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, err } // idempotencyWorkflow increments a global counter and returns the input -func incrementCounter(_ context.Context, value ...any) (int64, error) { - if len(value) == 0 { - return 0, fmt.Errorf("expected int64 value") - } - val, ok := value[0].(int64) - if !ok { - return 0, fmt.Errorf("expected int64, got %T", value[0]) - } - fmt.Println("incrementCounter called with value:", val) - idempotencyCounter += val - fmt.Println("Current idempotency counter:", idempotencyCounter) +func incrementCounter(_ context.Context, value int64) (int64, error) { + idempotencyCounter += value return idempotencyCounter, nil } @@ -299,8 +289,8 @@ func TestWorkflowsRegistration(t *testing.T) { } } -func stepWithinAStep(ctx context.Context, input ...any) (string, error) { - return simpleStep(ctx, input...) +func stepWithinAStep(ctx context.Context, input string) (string, error) { + return simpleStep(ctx, input) } func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) { @@ -310,20 +300,15 @@ func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) // Global counter for retry testing var stepRetryAttemptCount int -func stepRetryAlwaysFailsStep(ctx context.Context, input ...any) (string, error) { +func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) { stepRetryAttemptCount++ return "", fmt.Errorf("always fails - attempt %d", stepRetryAttemptCount) } var stepIdempotencyCounter int -func stepIdempotencyTest(ctx context.Context, input ...any) (string, error) { +func stepIdempotencyTest(ctx context.Context, input int) (string, error) { stepIdempotencyCounter++ - if len(input) > 0 { - if str, ok := input[0].(string); ok { - return str, nil - } - } return "", nil } @@ -335,7 +320,7 @@ func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) { MaxInterval: 10 * time.Millisecond, }) - return RunAsStep[string](stepCtx, stepRetryAlwaysFailsStep, input) + return RunAsStep(stepCtx, stepRetryAlwaysFailsStep, input) } // TODO: step params @@ -574,7 +559,7 @@ func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) { var blockingStepStopEvent *Event -func blockingStep(ctx context.Context, input ...any) (string, error) { +func blockingStep(ctx context.Context, input string) (string, error) { blockingStepStopEvent.Wait() return "", nil } @@ -619,6 +604,10 @@ func TestWorkflowIdempotency(t *testing.T) { t.Fatalf("failed to get result from second execution: %v", err) } + if handle1.GetWorkflowID() != handle2.GetWorkflowID() { + t.Fatalf("expected both handles to represent the same workflow ID, got %s and %s", handle2.GetWorkflowID(), handle1.GetWorkflowID()) + } + // Verify the second handle is a polling handle _, ok := handle2.(*workflowPollingHandle[string]) if !ok { @@ -751,7 +740,6 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Start a workflow that blocks forever wfID := uuid.NewString() - fmt.Println(wfID) handle, err := RunAsWorkflow(dbosCtx, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start dead letter queue workflow: %v", err) @@ -1089,17 +1077,10 @@ func receiveIdempotencyWorkflow(ctx DBOSContext, topic string) (string, error) { return msg, nil } -func stepThatCallsSend(ctx context.Context, input ...any) (string, error) { - if len(input) == 0 { - return "", fmt.Errorf("expected sendWorkflowInput") - } - i, ok := input[0].(sendWorkflowInput) - if !ok { - return "", fmt.Errorf("expected sendWorkflowInput, got %T", input[0]) - } +func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) { err := Send(ctx.(DBOSContext), WorkflowSendInput[string]{ - DestinationID: i.DestinationID, - Topic: i.Topic, + DestinationID: input.DestinationID, + Topic: input.Topic, Message: "message-from-step", }) if err != nil { @@ -1948,6 +1929,11 @@ func TestSleep(t *testing.T) { } sleepStopEvent.Set() + + _, err = handle.GetResult() + if err != nil { + t.Fatalf("failed to get sleep workflow result: %v", err) + } }) t.Run("SleepCannotBeCalledOutsideWorkflow", func(t *testing.T) { From a86463f2852e52617869108ae344951bb285b38f Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 17:11:36 -0700 Subject: [PATCH 26/30] nit --- dbos/admin_server.go | 4 +- dbos/admin_server_test.go | 3 -- dbos/recovery.go | 14 +++---- dbos/workflow.go | 80 +++++++++++++++++++-------------------- 4 files changed, 49 insertions(+), 52 deletions(-) diff --git a/dbos/admin_server.go b/dbos/admin_server.go index 364f3960..c4b4b27d 100644 --- a/dbos/admin_server.go +++ b/dbos/admin_server.go @@ -25,7 +25,7 @@ type queueMetadata struct { RateLimit *RateLimiter `json:"rateLimit,omitempty"` } -func newAdminServer(dbosCtx *dbosContext, port int) *adminServer { +func newAdminServer(ctx *dbosContext, port int) *adminServer { mux := http.NewServeMux() // Health endpoint @@ -50,7 +50,7 @@ func newAdminServer(dbosCtx *dbosContext, port int) *adminServer { getLogger().Info("Recovering workflows for executors", "executors", executorIDs) - handles, err := recoverPendingWorkflows(dbosCtx, executorIDs) + handles, err := recoverPendingWorkflows(ctx, executorIDs) if err != nil { getLogger().Error("Error recovering workflows", "error", err) http.Error(w, fmt.Sprintf("Recovery failed: %v", err), http.StatusInternalServerError) diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index d7eb2a2a..5d11f8b1 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -56,9 +56,6 @@ func TestAdminServer(t *testing.T) { }) t.Run("Admin server endpoints", func(t *testing.T) { - // Clean up any existing instance - // (This will be handled by the individual executor cleanup) - // Launch DBOS with admin server once for all endpoint tests ctx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, diff --git a/dbos/recovery.go b/dbos/recovery.go index bfa74d7a..21d081cb 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -4,13 +4,13 @@ import ( "strings" ) -func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { +func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors - pendingWorkflows, err := dbosCtx.systemDB.ListWorkflows(dbosCtx.ctx, listWorkflowsDBInput{ + pendingWorkflows, err := ctx.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusPending}, executorIDs: executorIDs, - applicationVersion: dbosCtx.applicationVersion, + applicationVersion: ctx.applicationVersion, }) if err != nil { return nil, err @@ -25,18 +25,18 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work } if workflow.QueueName != "" { - cleared, err := dbosCtx.systemDB.ClearQueueAssignment(dbosCtx.ctx, workflow.ID) + cleared, err := ctx.systemDB.ClearQueueAssignment(ctx.ctx, workflow.ID) if err != nil { getLogger().Error("Error clearing queue assignment for workflow", "workflow_id", workflow.ID, "name", workflow.Name, "error", err) continue } if cleared { - workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, dbosContext: dbosCtx}) + workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, dbosContext: ctx}) } continue } - registeredWorkflow, exists := dbosCtx.workflowRegistry[workflow.Name] + registeredWorkflow, exists := ctx.workflowRegistry[workflow.Name] if !exists { getLogger().Error("Workflow function not found in registry", "workflow_id", workflow.ID, "name", workflow.Name) continue @@ -47,7 +47,7 @@ func recoverPendingWorkflows(dbosCtx *dbosContext, executorIDs []string) ([]Work WithWorkflowID(workflow.ID), } // Create a workflow context from the executor context - handle, err := registeredWorkflow.wrappedFunction(dbosCtx, workflow.Input, opts...) + handle, err := registeredWorkflow.wrappedFunction(ctx, workflow.Input, opts...) if err != nil { return nil, err } diff --git a/dbos/workflow.go b/dbos/workflow.go index e32aa671..28a76c37 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -200,7 +200,7 @@ func (h *workflowPollingHandle[R]) GetWorkflowID() string { /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ -type GenericWrappedWorkflowFunc[P any, R any] func(dbosCtx DBOSContext, input P, opts ...WorkflowOption) (WorkflowHandle[R], error) +type GenericWrappedWorkflowFunc[P any, R any] func(ctx DBOSContext, input P, opts ...WorkflowOption) (WorkflowHandle[R], error) type WrappedWorkflowFunc func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) type workflowRegistryEntry struct { @@ -209,9 +209,9 @@ type workflowRegistryEntry struct { } // Register adds a workflow function to the registry (thread-safe, only once per name) -func registerWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkflowFunc, maxRetries int) { +func registerWorkflow(ctx DBOSContext, workflowName string, fn WrappedWorkflowFunc, maxRetries int) { // Skip if we don't have a concrete dbosContext - c, ok := dbosCtx.(*dbosContext) + c, ok := ctx.(*dbosContext) if !ok { return } @@ -230,9 +230,9 @@ func registerWorkflow(dbosCtx DBOSContext, workflowName string, fn WrappedWorkfl } } -func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn WorkflowFunc, cronSchedule string) { +func registerScheduledWorkflow(ctx DBOSContext, workflowName string, fn WorkflowFunc, cronSchedule string) { // Skip if we don't have a concrete dbosContext - c, ok := dbosCtx.(*dbosContext) + c, ok := ctx.(*dbosContext) if !ok { return } @@ -257,7 +257,7 @@ func registerScheduledWorkflow(dbosCtx DBOSContext, workflowName string, fn Work WithQueue(_DBOS_INTERNAL_QUEUE_NAME), withWorkflowName(workflowName), } - dbosCtx.RunAsWorkflow(dbosCtx, fn, scheduledTime, opts...) + ctx.RunAsWorkflow(ctx, fn, scheduledTime, opts...) }) if err != nil { panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) @@ -293,9 +293,9 @@ func WithSchedule(schedule string) workflowRegistrationOption { // RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding // The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions // To execute the workflow, use DBOSContext.RunAsWorkflow -func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { - if dbosCtx == nil { - panic("dbosCtx cannot be nil") +func RegisterWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { + if ctx == nil { + panic("ctx cannot be nil") } if fn == nil { @@ -334,21 +334,21 @@ func RegisterWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[ return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } - opts = append(opts, withWorkflowName(fqn)) // Append the name so dbosCtx.RunAsWorkflow can look it up from the registry to apply registration-time options + opts = append(opts, withWorkflowName(fqn)) // Append the name so ctx.RunAsWorkflow can look it up from the registry to apply registration-time options handle, err := ctx.RunAsWorkflow(ctx, typedErasedWorkflow, typedInput, opts...) if err != nil { return nil, err } return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), dbosContext: ctx}, nil }) - registerWorkflow(dbosCtx, fqn, typeErasedWrapper, registrationParams.maxRetries) + registerWorkflow(ctx, fqn, typeErasedWrapper, registrationParams.maxRetries) // If this is a scheduled workflow, register a cron job if registrationParams.cronSchedule != "" { if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { panic(fmt.Sprintf("scheduled workflow function must accept a time.Time as input, got %T", p)) } - registerScheduledWorkflow(dbosCtx, fqn, typedErasedWorkflow, registrationParams.cronSchedule) + registerScheduledWorkflow(ctx, fqn, typedErasedWorkflow, registrationParams.cronSchedule) } } @@ -417,9 +417,9 @@ func withWorkflowName(name string) WorkflowOption { } } -func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { - if dbosCtx == nil { - return nil, fmt.Errorf("dbosCtx cannot be nil") +func RunAsWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { + if ctx == nil { + return nil, fmt.Errorf("ctx cannot be nil") } // Add the fn name to the options so we can communicate it with DBOSContext.RunAsWorkflow @@ -429,7 +429,7 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, return fn(ctx, input.(P)) }) - handle, err := dbosCtx.(*dbosContext).RunAsWorkflow(dbosCtx, typedErasedWorkflow, input, opts...) + handle, err := ctx.(*dbosContext).RunAsWorkflow(ctx, typedErasedWorkflow, input, opts...) if err != nil { return nil, err } @@ -439,7 +439,7 @@ func RunAsWorkflow[P any, R any](dbosCtx DBOSContext, fn GenericWorkflowFunc[P, // We need to convert the polling handle to a typed handle typedPollingHandle := &workflowPollingHandle[R]{ workflowID: pollingHandle.workflowID, - dbosContext: dbosCtx, + dbosContext: ctx, } return typedPollingHandle, nil } @@ -687,9 +687,9 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams { var typeErasedStepNameToStepName = make(map[string]string) -func RunAsStep[P any, R any](dbosCtx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) { - if dbosCtx == nil { - return *new(R), newStepExecutionError("", "", "dbosCtx cannot be nil") +func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) { + if ctx == nil { + return *new(R), newStepExecutionError("", "", "ctx cannot be nil") } if fn == nil { @@ -704,7 +704,7 @@ func RunAsStep[P any, R any](dbosCtx DBOSContext, fn GenericStepFunc[P, R], inpu typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // Call the executor method - result, err := dbosCtx.RunAsStep(dbosCtx, typeErasedFn, input) + result, err := ctx.RunAsStep(ctx, typeErasedFn, input) if err != nil { // In case the errors comes from the DBOS step logic, the result will be nil and we must handle it if result == nil { @@ -846,13 +846,13 @@ func (c *dbosContext) Send(_ DBOSContext, input WorkflowSendInputInternal) error // Send sends a message to another workflow. // Send automatically registers the type of R for gob encoding -func Send[R any](dbosCtx DBOSContext, input WorkflowSendInput[R]) error { - if dbosCtx == nil { - return errors.New("dbosCtx cannot be nil") +func Send[R any](ctx DBOSContext, input WorkflowSendInput[R]) error { + if ctx == nil { + return errors.New("ctx cannot be nil") } var typedMessage R gob.Register(typedMessage) - return dbosCtx.Send(dbosCtx, WorkflowSendInputInternal{ + return ctx.Send(ctx, WorkflowSendInputInternal{ DestinationID: input.DestinationID, Message: input.Message, Topic: input.Topic, @@ -868,11 +868,11 @@ func (c *dbosContext) Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) return c.systemDB.Recv(c.ctx, input) } -func Recv[R any](dbosCtx DBOSContext, input WorkflowRecvInput) (R, error) { - if dbosCtx == nil { - return *new(R), errors.New("dbosCtx cannot be nil") +func Recv[R any](ctx DBOSContext, input WorkflowRecvInput) (R, error) { + if ctx == nil { + return *new(R), errors.New("ctx cannot be nil") } - msg, err := dbosCtx.Recv(dbosCtx, input) + msg, err := ctx.Recv(ctx, input) if err != nil { return *new(R), err } @@ -900,13 +900,13 @@ func (c *dbosContext) SetEvent(_ DBOSContext, input WorkflowSetEventInput) error // Sets an event from a workflow. // The event is a key value pair // SetEvent automatically registers the type of R for gob encoding -func SetEvent[R any](dbosCtx DBOSContext, input WorkflowSetEventInputGeneric[R]) error { - if dbosCtx == nil { - return errors.New("dbosCtx cannot be nil") +func SetEvent[R any](ctx DBOSContext, input WorkflowSetEventInputGeneric[R]) error { + if ctx == nil { + return errors.New("ctx cannot be nil") } var typedMessage R gob.Register(typedMessage) - return dbosCtx.SetEvent(dbosCtx, WorkflowSetEventInput{ + return ctx.SetEvent(ctx, WorkflowSetEventInput{ Key: input.Key, Message: input.Message, }) @@ -922,11 +922,11 @@ func (c *dbosContext) GetEvent(_ DBOSContext, input WorkflowGetEventInput) (any, return c.systemDB.GetEvent(c.ctx, input) } -func GetEvent[R any](dbosCtx DBOSContext, input WorkflowGetEventInput) (R, error) { - if dbosCtx == nil { +func GetEvent[R any](ctx DBOSContext, input WorkflowGetEventInput) (R, error) { + if ctx == nil { return *new(R), errors.New("dbosCtx cannot be nil") } - value, err := dbosCtx.GetEvent(dbosCtx, input) + value, err := ctx.GetEvent(ctx, input) if err != nil { return *new(R), err } @@ -971,11 +971,11 @@ func (c *dbosContext) RetrieveWorkflow(_ DBOSContext, workflowID string) (Workfl return &workflowPollingHandle[any]{workflowID: workflowID, dbosContext: c}, nil } -func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { - if dbosCtx == nil { +func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { + if ctx == nil { return workflowPollingHandle[R]{}, errors.New("dbosCtx cannot be nil") } - workflowStatus, err := dbosCtx.(*dbosContext).systemDB.ListWorkflows(dbosCtx.(*dbosContext).ctx, listWorkflowsDBInput{ + workflowStatus, err := ctx.(*dbosContext).systemDB.ListWorkflows(ctx.(*dbosContext).ctx, listWorkflowsDBInput{ workflowIDs: []string{workflowID}, }) if err != nil { @@ -984,5 +984,5 @@ func RetrieveWorkflow[R any](dbosCtx DBOSContext, workflowID string) (workflowPo if len(workflowStatus) == 0 { return workflowPollingHandle[R]{}, newNonExistentWorkflowError(workflowID) } - return workflowPollingHandle[R]{workflowID: workflowID, dbosContext: dbosCtx}, nil + return workflowPollingHandle[R]{workflowID: workflowID, dbosContext: ctx}, nil } From 92fe03f4c9080a219fd5415e0020a0cedd9b9b42 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 17:15:13 -0700 Subject: [PATCH 27/30] revert unwanted changes --- dbos/system_database.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/dbos/system_database.go b/dbos/system_database.go index 9ceb7551..39db27e3 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -251,7 +251,6 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW if input.tx == nil { return nil, errors.New("transaction is required for InsertWorkflowStatus") } - tx := input.tx // Set default values attempts := 1 @@ -313,8 +312,8 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW END RETURNING recovery_attempts, status, name, queue_name, workflow_deadline_epoch_ms` - result := insertWorkflowResult{} - err = tx.QueryRow(ctx, query, + var result insertWorkflowResult + err = input.tx.QueryRow(ctx, query, input.status.ID, input.status.Status, input.status.Name, @@ -363,7 +362,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW SET status = $1, deduplication_id = NULL, started_at_epoch_ms = NULL, queue_name = NULL WHERE workflow_uuid = $2 AND status = $3` - _, err = tx.Exec(ctx, dlqQuery, + _, err = input.tx.Exec(ctx, dlqQuery, WorkflowStatusRetriesExceeded, input.status.ID, WorkflowStatusPending) @@ -373,7 +372,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertW } // Commit the transaction before throwing the error - if err := tx.Commit(ctx); err != nil { + if err := input.tx.Commit(ctx); err != nil { return nil, fmt.Errorf("failed to commit transaction after marking workflow as RETRIES_EXCEEDED: %w", err) } From 38aec116ed1461ff286478b9f32c77cfc5fcc790 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 17:35:24 -0700 Subject: [PATCH 28/30] remove unused WithWorkflowMaxRetries --- dbos/workflow.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index 28a76c37..f9da8859 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -290,7 +290,7 @@ func WithSchedule(schedule string) workflowRegistrationOption { // RegisterWorkflow registers the provided function as a durable workflow with the provided DBOSContext workflow registry // If the workflow is a scheduled workflow (determined by the presence of a cron schedule), it will also register a cron job to execute it -// RegisterWorkflow is generically typed, allowing us to register the workflow input and output types for gob encoding +// RegisterWorkflow is generically typed, providing compile-time type checking and allowing us to register the workflow input and output types for gob encoding // The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions // To execute the workflow, use DBOSContext.RunAsWorkflow func RegisterWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { @@ -405,12 +405,6 @@ func WithApplicationVersion(version string) WorkflowOption { } } -func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { - return func(p *workflowParams) { - p.maxRetries = maxRetries - } -} - func withWorkflowName(name string) WorkflowOption { return func(p *workflowParams) { p.workflowName = name From f6c8fc5a16dc2de880834acbad7d58f3233fe0ca Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 17:49:10 -0700 Subject: [PATCH 29/30] type check + nits --- dbos/workflow.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dbos/workflow.go b/dbos/workflow.go index f9da8859..e321f506 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -692,7 +692,11 @@ func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) // Type-erase the function based on its actual type typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) { - return fn(ctx, input.(P)) + typedInput, ok := input.(P) + if !ok { + return nil, newStepExecutionError("", "", fmt.Sprintf("unexpected input type: expected %T, got %T", *new(P), input)) + } + return fn(ctx, typedInput) }) typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() @@ -718,7 +722,7 @@ func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) { // Get workflow state from context - wfState, ok := c.ctx.Value(workflowStateKey).(*workflowState) + wfState, ok := c.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { // TODO: try to print step name return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?") @@ -738,7 +742,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, err // If within a step, just run the function directly if wfState.isWithinStep { - return fn(c.ctx, input) + return fn(c, input) } // Setup step state From d067655dc1152954783165967045f8a6b8e24df9 Mon Sep 17 00:00:00 2001 From: maxdml Date: Fri, 1 Aug 2025 17:58:22 -0700 Subject: [PATCH 30/30] simplify tests --- dbos/dbos_test.go | 39 ++++++++++++++++++++ dbos/initialize_test.go | 82 ----------------------------------------- 2 files changed, 39 insertions(+), 82 deletions(-) delete mode 100644 dbos/initialize_test.go diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 10de8b14..440653b4 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -7,6 +7,45 @@ import ( func TestConfigValidationErrorTypes(t *testing.T) { databaseURL := getDatabaseURL() + t.Run("CreatesDBOSContext", func(t *testing.T) { + t.Setenv("DBOS__APPVERSION", "v1.0.0") + t.Setenv("DBOS__APPID", "test-app-id") + t.Setenv("DBOS__VMID", "test-executor-id") + ctx, err := NewDBOSContext(Config{ + DatabaseURL: databaseURL, + AppName: "test-initialize", + }) + if err != nil { + t.Fatalf("Failed to initialize DBOS: %v", err) + } + defer func() { + if ctx != nil { + ctx.Shutdown() + } + }() // Clean up executor + + if ctx == nil { + t.Fatal("Initialize returned nil executor") + } + + // Test that executor implements DBOSContext interface + var _ DBOSContext = ctx + + // Test that we can call methods on the executor + appVersion := ctx.GetApplicationVersion() + if appVersion != "v1.0.0" { + t.Fatal("GetApplicationVersion returned empty string") + } + executorID := ctx.GetExecutorID() + if executorID != "test-executor-id" { + t.Fatal("GetExecutorID returned empty string") + } + appID := ctx.GetApplicationID() + if appID != "test-app-id" { + t.Fatal("GetApplicationID returned empty string") + } + }) + t.Run("FailsWithoutAppName", func(t *testing.T) { config := Config{ DatabaseURL: databaseURL, diff --git a/dbos/initialize_test.go b/dbos/initialize_test.go deleted file mode 100644 index d15e8cf3..00000000 --- a/dbos/initialize_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package dbos - -import ( - "testing" -) - -// TestInitializeReturnsExecutor verifies that our updated Initialize function works correctly -func TestInitializeReturnsExecutor(t *testing.T) { - databaseURL := getDatabaseURL() - - // Test that Initialize returns a DBOSExecutor - ctx, err := NewDBOSContext(Config{ - DatabaseURL: databaseURL, - AppName: "test-initialize", - }) - if err != nil { - t.Fatalf("Failed to initialize DBOS: %v", err) - } - defer func() { - if ctx != nil { - ctx.Shutdown() - } - }() // Clean up executor - - if ctx == nil { - t.Fatal("Initialize returned nil executor") - } - - // Test that executor implements DBOSContext interface - var _ DBOSContext = ctx - - // Test that we can call methods on the executor - appVersion := ctx.GetApplicationVersion() - if appVersion == "" { - t.Fatal("GetApplicationVersion returned empty string") - } - - scheduler := ctx.(*dbosContext).getWorkflowScheduler() - if scheduler == nil { - t.Fatal("getWorkflowScheduler returned nil") - } -} - -// TestWithWorkflowWithExecutor verifies that WithWorkflow works with an executor -func TestWithWorkflowWithExecutor(t *testing.T) { - ctx := setupDBOS(t) - - // Test workflow function - testWorkflow := func(ctx DBOSContext, input string) (string, error) { - return "hello " + input, nil - } - - // Test that RegisterWorkflow works with executor - RegisterWorkflow(ctx, testWorkflow) - - // Test executing the workflow - handle, err := RunAsWorkflow(ctx, testWorkflow, "world") - if err != nil { - t.Fatalf("Failed to execute workflow: %v", err) - } - - result, err := handle.GetResult() - if err != nil { - t.Fatalf("Failed to get workflow result: %v", err) - } - - expected := "hello world" - if result != expected { - t.Fatalf("Expected %q, got %q", expected, result) - } -} - -// TestSetupDBOSReturnsExecutor verifies that setupDBOS returns an executor -func TestSetupDBOSReturnsExecutor(t *testing.T) { - executor := setupDBOS(t) - - if executor == nil { - t.Fatal("setupDBOS returned nil executor") - } - - // Test succeeded - executor is valid -}