diff --git a/dbos/admin_server.go b/dbos/admin_server.go index b942c13a..c4b4b27d 100644 --- a/dbos/admin_server.go +++ b/dbos/admin_server.go @@ -25,7 +25,7 @@ type queueMetadata struct { RateLimit *RateLimiter `json:"rateLimit,omitempty"` } -func newAdminServer(port int) *adminServer { +func newAdminServer(ctx *dbosContext, port int) *adminServer { mux := http.NewServeMux() // Health endpoint @@ -50,7 +50,7 @@ func newAdminServer(port int) *adminServer { getLogger().Info("Recovering workflows for executors", "executors", executorIDs) - handles, err := recoverPendingWorkflows(r.Context(), executorIDs) + handles, err := recoverPendingWorkflows(ctx, executorIDs) if err != nil { getLogger().Error("Error recovering workflows", "error", err) http.Error(w, fmt.Sprintf("Recovery failed: %v", err), http.StatusInternalServerError) diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 9f971ffa..5d11f8b1 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -11,27 +11,27 @@ import ( ) func TestAdminServer(t *testing.T) { - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() t.Run("Admin server is not started by default", func(t *testing.T) { - // Ensure clean state - Shutdown() - err := Initialize(Config{ + ctx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) if err != nil { t.Skipf("Failed to initialize DBOS: %v", err) } - err = Launch() + err = ctx.Launch() if err != nil { t.Skipf("Failed to initialize DBOS: %v", err) } // Ensure cleanup defer func() { - Shutdown() + if ctx != nil { + ctx.Shutdown() + } }() // Give time for any startup processes @@ -45,20 +45,19 @@ func TestAdminServer(t *testing.T) { } // Verify the DBOS executor doesn't have an admin server instance - if dbos == nil { + if ctx == nil { t.Fatal("Expected DBOS instance to be created") } - if dbos.adminServer != nil { + exec := ctx.(*dbosContext) + if exec.adminServer != nil { t.Error("Expected admin server to be nil when not configured") } }) t.Run("Admin server endpoints", func(t *testing.T) { - Shutdown() - // Launch DBOS with admin server once for all endpoint tests - err := Initialize(Config{ + ctx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", AdminServer: true, @@ -66,25 +65,28 @@ func TestAdminServer(t *testing.T) { if err != nil { t.Skipf("Failed to initialize DBOS with admin server: %v", err) } - err = Launch() + err = ctx.Launch() if err != nil { t.Skipf("Failed to initialize DBOS with admin server: %v", err) } // Ensure cleanup defer func() { - Shutdown() + if ctx != nil { + ctx.Shutdown() + } }() // Give the server a moment to start time.Sleep(100 * time.Millisecond) // Verify the DBOS executor has an admin server instance - if dbos == nil { + if ctx == nil { t.Fatal("Expected DBOS instance to be created") } - if dbos.adminServer == nil { + exec := ctx.(*dbosContext) + if exec.adminServer == nil { t.Fatal("Expected admin server to be created in DBOS instance") } diff --git a/dbos/dbos.go b/dbos/dbos.go index 60bd8c44..f2090ef2 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -19,12 +19,10 @@ var ( _DEFAULT_ADMIN_SERVER_PORT = 3001 ) -var workflowScheduler *cron.Cron // Global because accessed during workflow registration before the dbos singleton is initialized - var logger *slog.Logger // Global because accessed everywhere inside the library func getLogger() *slog.Logger { - if dbos == nil || logger == nil { + if logger == nil { return slog.New(slog.NewTextHandler(os.Stderr, nil)) } return logger @@ -62,35 +60,133 @@ func processConfig(inputConfig *Config) (*Config, error) { return dbosConfig, nil } -var dbos *executor // DBOS singleton instance +type DBOSContext interface { + context.Context + + // Context Lifecycle + Launch() error + Shutdown() + + // Workflow operations + RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) + RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) + Send(_ DBOSContext, input WorkflowSendInputInternal) error + Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) + SetEvent(_ DBOSContext, input WorkflowSetEventInput) error + GetEvent(_ DBOSContext, input WorkflowGetEventInput) (any, error) + Sleep(duration time.Duration) (time.Duration, error) + GetWorkflowID() (string, error) + + // Workflow management + RetrieveWorkflow(_ DBOSContext, workflowID string) (WorkflowHandle[any], error) + + // Accessors + GetApplicationVersion() string + GetExecutorID() string + GetApplicationID() string +} + +type dbosContext struct { + ctx context.Context + + launched bool + + systemDB SystemDatabase + adminServer *adminServer + config *Config -type executor struct { - systemDB SystemDatabase + // Queue runner context and cancel function queueRunnerCtx context.Context queueRunnerCancelFunc context.CancelFunc queueRunnerDone chan struct{} - adminServer *adminServer - config *Config - applicationVersion string - applicationID string - executorID string - workflowsWg *sync.WaitGroup + + // Application metadata + applicationVersion string + applicationID string + executorID string + + // Wait group for workflow goroutines + workflowsWg *sync.WaitGroup + + // Workflow registry + workflowRegistry map[string]workflowRegistryEntry + workflowRegMutex *sync.RWMutex + + // Workflow scheduler + workflowScheduler *cron.Cron +} + +// Implement contex.Context interface methods +func (c *dbosContext) Deadline() (deadline time.Time, ok bool) { + return c.ctx.Deadline() +} + +func (c *dbosContext) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c *dbosContext) Err() error { + return c.ctx.Err() +} + +func (c *dbosContext) Value(key any) any { + return c.ctx.Value(key) } -func Initialize(inputConfig Config) error { - if dbos != nil { - fmt.Println("warning: DBOS instance already initialized, skipping re-initialization") - return newInitializationError("DBOS already initialized") +// Create a new context +// This is intended for workflow contexts and step contexts +// Hence we only set the relevant fields +func WithValue(ctx DBOSContext, key, val any) DBOSContext { + if ctx == nil { + return nil } + // Will do nothing if the concrete type is not dbosContext + if dbosCtx, ok := ctx.(*dbosContext); ok { + return &dbosContext{ + ctx: context.WithValue(dbosCtx.ctx, key, val), + systemDB: dbosCtx.systemDB, + workflowsWg: dbosCtx.workflowsWg, + workflowRegistry: dbosCtx.workflowRegistry, + workflowRegMutex: dbosCtx.workflowRegMutex, + applicationVersion: dbosCtx.applicationVersion, + executorID: dbosCtx.executorID, + applicationID: dbosCtx.applicationID, + } + } + return nil +} + +func (c *dbosContext) getWorkflowScheduler() *cron.Cron { + if c.workflowScheduler == nil { + c.workflowScheduler = cron.New(cron.WithSeconds()) + } + return c.workflowScheduler +} + +func (c *dbosContext) GetApplicationVersion() string { + return c.applicationVersion +} + +func (c *dbosContext) GetExecutorID() string { + return c.executorID +} + +func (c *dbosContext) GetApplicationID() string { + return c.applicationID +} - initExecutor := &executor{ - workflowsWg: &sync.WaitGroup{}, +func NewDBOSContext(inputConfig Config) (DBOSContext, error) { + initExecutor := &dbosContext{ + workflowsWg: &sync.WaitGroup{}, + ctx: context.Background(), + workflowRegistry: make(map[string]workflowRegistryEntry), + workflowRegMutex: &sync.RWMutex{}, } - // Load & process the configuration + // Load and process the configuration config, err := processConfig(&inputConfig) if err != nil { - return newInitializationError(err.Error()) + return nil, newInitializationError(err.Error()) } initExecutor.config = config @@ -119,57 +215,56 @@ func Initialize(inputConfig Config) error { // Create the system database systemDB, err := NewSystemDatabase(config.DatabaseURL) if err != nil { - return newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) + return nil, newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) } initExecutor.systemDB = systemDB logger.Info("System database initialized") - // Set the global dbos instance - dbos = initExecutor - - return nil + return initExecutor, nil } -func Launch() error { - if dbos == nil { - return newInitializationError("DBOS instance not initialized, call Initialize first") +func (c *dbosContext) Launch() error { + if c.launched { + return newInitializationError("DBOS is already launched") } + // Start the system database - dbos.systemDB.Launch(context.Background()) + c.systemDB.Launch(context.Background()) // Start the admin server if configured - if dbos.config.AdminServer { - adminServer := newAdminServer(_DEFAULT_ADMIN_SERVER_PORT) + if c.config.AdminServer { + adminServer := newAdminServer(c, _DEFAULT_ADMIN_SERVER_PORT) err := adminServer.Start() if err != nil { logger.Error("Failed to start admin server", "error", err) return newInitializationError(fmt.Sprintf("failed to start admin server: %v", err)) } logger.Info("Admin server started", "port", _DEFAULT_ADMIN_SERVER_PORT) - dbos.adminServer = adminServer + c.adminServer = adminServer } // Create context with cancel function for queue runner - ctx, cancel := context.WithCancel(context.Background()) - dbos.queueRunnerCtx = ctx - dbos.queueRunnerCancelFunc = cancel - dbos.queueRunnerDone = make(chan struct{}) + // FIXME: cancellation now has to go through the DBOSContext + ctx, cancel := context.WithCancel(c.ctx) + c.queueRunnerCtx = ctx + c.queueRunnerCancelFunc = cancel + c.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(dbos.queueRunnerDone) - queueRunner(ctx) + defer close(c.queueRunnerDone) + queueRunner(c) }() logger.Info("Queue runner started") // Start the workflow scheduler if it has been initialized - if workflowScheduler != nil { - workflowScheduler.Start() + if c.workflowScheduler != nil { + c.workflowScheduler.Start() logger.Info("Workflow scheduler started") } // Run a round of recovery on the local executor - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{dbos.executorID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it? + recoveryHandles, err := recoverPendingWorkflows(c, []string{c.executorID}) if err != nil { return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err)) } @@ -177,29 +272,33 @@ func Launch() error { logger.Info("Recovered pending workflows", "count", len(recoveryHandles)) } - logger.Info("DBOS initialized", "app_version", dbos.applicationVersion, "executor_id", dbos.executorID) + logger.Info("DBOS initialized", "app_version", c.applicationVersion, "executor_id", c.executorID) + c.launched = true return nil } -func Shutdown() { - if dbos == nil { - fmt.Println("DBOS instance is nil, cannot shutdown") +func (c *dbosContext) Shutdown() { + if !c.launched { + logger.Warn("DBOS is not launched, nothing to shutdown") return } - // XXX is there a way to ensure all workflows goroutine are done before closing? - dbos.workflowsWg.Wait() + // Wait for all workflows to finish + getLogger().Info("Waiting for all workflows to finish") + c.workflowsWg.Wait() // Cancel the context to stop the queue runner - if dbos.queueRunnerCancelFunc != nil { - dbos.queueRunnerCancelFunc() + if c.queueRunnerCancelFunc != nil { + getLogger().Info("Stopping queue runner") + c.queueRunnerCancelFunc() // Wait for queue runner to finish - <-dbos.queueRunnerDone + <-c.queueRunnerDone getLogger().Info("Queue runner stopped") } - if workflowScheduler != nil { - ctx := workflowScheduler.Stop() + if c.workflowScheduler != nil { + getLogger().Info("Stopping workflow scheduler") + ctx := c.workflowScheduler.Stop() // Wait for all running jobs to complete with 5-second timeout timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -212,25 +311,26 @@ func Shutdown() { } } - if dbos.systemDB != nil { - dbos.systemDB.Shutdown() - dbos.systemDB = nil + if c.systemDB != nil { + getLogger().Info("Shutting down system database") + c.systemDB.Shutdown() + c.systemDB = nil } - if dbos.adminServer != nil { - err := dbos.adminServer.Shutdown() + if c.adminServer != nil { + getLogger().Info("Shutting down admin server") + err := c.adminServer.Shutdown() if err != nil { getLogger().Error("Failed to shutdown admin server", "error", err) } else { getLogger().Info("Admin server shutdown complete") } - dbos.adminServer = nil + c.adminServer = nil } if logger != nil { logger = nil } - dbos = nil } func GetBinaryHash() (string, error) { diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 192b4c8c..440653b4 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -5,14 +5,53 @@ import ( ) func TestConfigValidationErrorTypes(t *testing.T) { - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() + + t.Run("CreatesDBOSContext", func(t *testing.T) { + t.Setenv("DBOS__APPVERSION", "v1.0.0") + t.Setenv("DBOS__APPID", "test-app-id") + t.Setenv("DBOS__VMID", "test-executor-id") + ctx, err := NewDBOSContext(Config{ + DatabaseURL: databaseURL, + AppName: "test-initialize", + }) + if err != nil { + t.Fatalf("Failed to initialize DBOS: %v", err) + } + defer func() { + if ctx != nil { + ctx.Shutdown() + } + }() // Clean up executor + + if ctx == nil { + t.Fatal("Initialize returned nil executor") + } + + // Test that executor implements DBOSContext interface + var _ DBOSContext = ctx + + // Test that we can call methods on the executor + appVersion := ctx.GetApplicationVersion() + if appVersion != "v1.0.0" { + t.Fatal("GetApplicationVersion returned empty string") + } + executorID := ctx.GetExecutorID() + if executorID != "test-executor-id" { + t.Fatal("GetExecutorID returned empty string") + } + appID := ctx.GetApplicationID() + if appID != "test-app-id" { + t.Fatal("GetApplicationID returned empty string") + } + }) t.Run("FailsWithoutAppName", func(t *testing.T) { config := Config{ DatabaseURL: databaseURL, } - err := Initialize(config) + _, err := NewDBOSContext(config) if err == nil { t.Fatal("expected error when app name is missing, but got none") } @@ -37,7 +76,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { AppName: "test-app", } - err := Initialize(config) + _, err := NewDBOSContext(config) if err == nil { t.Fatal("expected error when database URL is missing, but got none") } diff --git a/dbos/logger_test.go b/dbos/logger_test.go index 2b7eeebb..72a9acf0 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -8,22 +8,24 @@ import ( ) func TestLogger(t *testing.T) { - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() t.Run("Default logger", func(t *testing.T) { - err := Initialize(Config{ + dbosCtx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) // Create executor with default logger if err != nil { t.Fatalf("Failed to create executor with default logger: %v", err) } - err = Launch() + err = dbosCtx.Launch() if err != nil { t.Fatalf("Failed to launch with default logger: %v", err) } t.Cleanup(func() { - Shutdown() + if dbosCtx != nil { + dbosCtx.Shutdown() + } }) if logger == nil { @@ -45,7 +47,7 @@ func TestLogger(t *testing.T) { // Add some context to the slog logger slogLogger = slogLogger.With("service", "dbos-test", "environment", "test") - err := Initialize(Config{ + dbosCtx, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", Logger: slogLogger, @@ -53,12 +55,14 @@ func TestLogger(t *testing.T) { if err != nil { t.Fatalf("Failed to create executor with custom logger: %v", err) } - err = Launch() + err = dbosCtx.Launch() if err != nil { t.Fatalf("Failed to launch with custom logger: %v", err) } t.Cleanup(func() { - Shutdown() + if dbosCtx != nil { + dbosCtx.Shutdown() + } }) if logger == nil { diff --git a/dbos/queue.go b/dbos/queue.go index 6cc6ec05..62042151 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -2,7 +2,6 @@ package dbos import ( "bytes" - "context" "encoding/base64" "encoding/gob" "math" @@ -73,10 +72,7 @@ func WithMaxTasksPerIteration(maxTasks int) queueOption { // NewWorkflowQueue creates a new workflow queue with optional configuration func NewWorkflowQueue(name string, options ...queueOption) WorkflowQueue { - if dbos != nil { - getLogger().Warn("NewWorkflowQueue called after DBOS initialization, dynamic registration is not supported") - return WorkflowQueue{} - } + // TODO: Add runtime check for post-initialization registration if needed if _, exists := workflowQueueRegistry[name]; exists { panic(newConflictingRegistrationError(name)) } @@ -102,7 +98,7 @@ func NewWorkflowQueue(name string, options ...queueOption) WorkflowQueue { return q } -func queueRunner(ctx context.Context) { +func queueRunner(ctx *dbosContext) { const ( baseInterval = 1.0 // Base interval in seconds minInterval = 1.0 // Minimum polling interval in seconds @@ -122,7 +118,7 @@ func queueRunner(ctx context.Context) { for queueName, queue := range workflowQueueRegistry { getLogger().Debug("Processing queue", "queue_name", queueName) // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := dbos.systemDB.DequeueWorkflows(ctx, queue) + dequeuedWorkflows, err := ctx.systemDB.DequeueWorkflows(ctx.queueRunnerCtx, queue, ctx.executorID, ctx.applicationVersion) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { @@ -143,7 +139,7 @@ func queueRunner(ctx context.Context) { } for _, workflow := range dequeuedWorkflows { // Find the workflow in the registry - registeredWorkflow, exists := registry[workflow.name] + registeredWorkflow, exists := ctx.workflowRegistry[workflow.name] if !exists { getLogger().Error("workflow function not found in registry", "workflow_name", workflow.name) continue @@ -165,6 +161,7 @@ func queueRunner(ctx context.Context) { } } + // XXX this demonstrate why contexts cannot be used globally -- the task does not inherit the context used in the program that enqueued it _, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id)) if err != nil { getLogger().Error("Error running queued workflow", "error", err) @@ -187,7 +184,7 @@ func queueRunner(ctx context.Context) { // Sleep with jittered interval, but allow early exit on context cancellation select { - case <-ctx.Done(): + case <-ctx.queueRunnerCtx.Done(): getLogger().Info("Queue runner stopping due to context cancellation") return case <-time.After(sleepDuration): diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 74a11c9c..488dbca2 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -27,24 +27,16 @@ This suite tests */ var ( - queue = NewWorkflowQueue("test-queue") - queueWf = WithWorkflow(queueWorkflow) - queueWfWithChild = WithWorkflow(queueWorkflowWithChild) - queueWfThatEnqueues = WithWorkflow(queueWorkflowThatEnqueues) + queue = NewWorkflowQueue("test-queue") // Variables for successive enqueue test - dlqEnqueueQueue = NewWorkflowQueue("test-successive-enqueue-queue") - dlqStartEvent = NewEvent() - dlqCompleteEvent = NewEvent() - dlqMaxRetries = 10 - enqueueWorkflowDLQ = WithWorkflow(func(ctx context.Context, input string) (string, error) { - dlqStartEvent.Set() - dlqCompleteEvent.Wait() - return input, nil - }, WithMaxRetries(dlqMaxRetries)) + dlqEnqueueQueue = NewWorkflowQueue("test-successive-enqueue-queue") + dlqStartEvent = NewEvent() + dlqCompleteEvent = NewEvent() + dlqMaxRetries = 10 ) -func queueWorkflow(ctx context.Context, input string) (string, error) { +func queueWorkflow(ctx DBOSContext, input string) (string, error) { step1, err := RunAsStep(ctx, queueStep, input) if err != nil { return "", fmt.Errorf("failed to run step: %v", err) @@ -56,43 +48,62 @@ func queueStep(ctx context.Context, input string) (string, error) { return input, nil } -func queueWorkflowWithChild(ctx context.Context, input string) (string, error) { - // Start a child workflow - childHandle, err := queueWf(ctx, input+"-child") - if err != nil { - return "", fmt.Errorf("failed to start child workflow: %v", err) - } +func TestWorkflowQueues(t *testing.T) { + dbosCtx := setupDBOS(t) - // Get result from child workflow - childResult, err := childHandle.GetResult(ctx) - if err != nil { - return "", fmt.Errorf("failed to get child result: %v", err) - } + // Register workflows with dbosContext + RegisterWorkflow(dbosCtx, queueWorkflow) - return childResult, nil -} + // Create workflow with child that can call the main workflow + queueWorkflowWithChild := func(ctx DBOSContext, input string) (string, error) { + // Start a child workflow + childHandle, err := RunAsWorkflow(ctx, queueWorkflow, input+"-child") + if err != nil { + return "", fmt.Errorf("failed to start child workflow: %v", err) + } -func queueWorkflowThatEnqueues(ctx context.Context, input string) (string, error) { - // Enqueue another workflow to the same queue - enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.name)) - if err != nil { - return "", fmt.Errorf("failed to enqueue workflow: %v", err) + // Get result from child workflow + childResult, err := childHandle.GetResult() + if err != nil { + return "", fmt.Errorf("failed to get child result: %v", err) + } + + return childResult, nil } + RegisterWorkflow(dbosCtx, queueWorkflowWithChild) - // Get result from the enqueued workflow - enqueuedResult, err := enqueuedHandle.GetResult(ctx) - if err != nil { - return "", fmt.Errorf("failed to get enqueued workflow result: %v", err) + // Create workflow that enqueues another workflow + queueWorkflowThatEnqueues := func(ctx DBOSContext, input string) (string, error) { + // Enqueue another workflow to the same queue + enqueuedHandle, err := RunAsWorkflow(ctx, queueWorkflow, input+"-enqueued", WithQueue(queue.name)) + if err != nil { + return "", fmt.Errorf("failed to enqueue workflow: %v", err) + } + + // Get result from the enqueued workflow + enqueuedResult, err := enqueuedHandle.GetResult() + if err != nil { + return "", fmt.Errorf("failed to get enqueued workflow result: %v", err) + } + + return enqueuedResult, nil } + RegisterWorkflow(dbosCtx, queueWorkflowThatEnqueues) - return enqueuedResult, nil -} + enqueueWorkflowDLQ := func(ctx DBOSContext, input string) (string, error) { + dlqStartEvent.Set() + dlqCompleteEvent.Wait() + return input, nil + } + RegisterWorkflow(dbosCtx, enqueueWorkflowDLQ, WithMaxRetries(dlqMaxRetries)) -func TestWorkflowQueues(t *testing.T) { - setupDBOS(t) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } t.Run("EnqueueWorkflow", func(t *testing.T) { - handle, err := queueWf(context.Background(), "test-input", WithQueue(queue.name)) + handle, err := RunAsWorkflow(dbosCtx, queueWorkflow, "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow: %v", err) } @@ -102,7 +113,7 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected handle to be of type workflowPollingHandle, got %T", handle) } - res, err := handle.GetResult(context.Background()) + res, err := handle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -110,18 +121,18 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected workflow result to be 'test-input', got %v", res) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } }) t.Run("EnqueuedWorkflowStartsChildWorkflow", func(t *testing.T) { - handle, err := queueWfWithChild(context.Background(), "test-input", WithQueue(queue.name)) + handle, err := RunAsWorkflow(dbosCtx, queueWorkflowWithChild, "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow with child: %v", err) } - res, err := handle.GetResult(context.Background()) + res, err := handle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -132,18 +143,18 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected workflow result to be '%s', got %v", expectedResult, res) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } }) t.Run("WorkflowEnqueuesAnotherWorkflow", func(t *testing.T) { - handle, err := queueWfThatEnqueues(context.Background(), "test-input", WithQueue(queue.name)) + handle, err := RunAsWorkflow(dbosCtx, queueWorkflowThatEnqueues, "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow that enqueues another workflow: %v", err) } - res, err := handle.GetResult(context.Background()) + res, err := handle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -154,23 +165,25 @@ func TestWorkflowQueues(t *testing.T) { t.Fatalf("expected workflow result to be '%s', got %v", expectedResult, res) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } }) + /* TODO: we will move queue registry in the new interface in a subsequent PR t.Run("DynamicRegistration", func(t *testing.T) { q := NewWorkflowQueue("dynamic-queue") if len(q.name) > 0 { t.Fatalf("expected nil queue for dynamic registration after DBOS initialization, got %v", q) } }) + */ t.Run("QueueWorkflowDLQ", func(t *testing.T) { workflowID := "blocking-workflow-test" // Enqueue the workflow for the first time - originalHandle, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) + originalHandle, err := RunAsWorkflow(dbosCtx, enqueueWorkflowDLQ, "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to enqueue blocking workflow: %v", err) } @@ -181,7 +194,7 @@ func TestWorkflowQueues(t *testing.T) { // Try to enqueue the same workflow more times for i := range dlqMaxRetries * 2 { - _, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) + _, err := RunAsWorkflow(dbosCtx, enqueueWorkflowDLQ, "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to enqueue workflow attempt %d: %v", i+1, err) } @@ -201,7 +214,7 @@ func TestWorkflowQueues(t *testing.T) { // Check that the workflow hits DLQ after re-running max retries handles := make([]WorkflowHandle[any], 0, dlqMaxRetries+1) for i := range dlqMaxRetries { - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveryHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -227,7 +240,7 @@ func TestWorkflowQueues(t *testing.T) { // Check the workflow completes dlqCompleteEvent.Set() for _, handle := range handles { - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from recovered workflow handle: %v", err) } @@ -236,7 +249,7 @@ func TestWorkflowQueues(t *testing.T) { } } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after successive enqueues test") } }) @@ -248,18 +261,25 @@ var ( recoveryStepCounter = 0 recoveryStepEvents = make([]*Event, 5) // 5 queued steps recoveryEvent = NewEvent() +) + +func TestQueueRecovery(t *testing.T) { + dbosCtx := setupDBOS(t) - recoveryStepWorkflow = WithWorkflow(func(ctx context.Context, i int) (int, error) { + // Create workflows with dbosContext + + recoveryStepWorkflowFunc := func(ctx DBOSContext, i int) (int, error) { recoveryStepCounter++ recoveryStepEvents[i].Set() recoveryEvent.Wait() return i, nil - }) + } + RegisterWorkflow(dbosCtx, recoveryStepWorkflowFunc) - recoveryWorkflow = WithWorkflow(func(ctx context.Context, input string) ([]int, error) { + recoveryWorkflowFunc := func(ctx DBOSContext, input string) ([]int, error) { handles := make([]WorkflowHandle[int], 0, 5) // 5 queued steps for i := range 5 { - handle, err := recoveryStepWorkflow(ctx, i, WithQueue(recoveryQueue.name)) + handle, err := RunAsWorkflow(ctx, recoveryStepWorkflowFunc, i, WithQueue(recoveryQueue.name)) if err != nil { return nil, fmt.Errorf("failed to enqueue step %d: %v", i, err) } @@ -268,18 +288,20 @@ var ( results := make([]int, 0, 5) for _, handle := range handles { - result, err := handle.GetResult(ctx) + result, err := handle.GetResult() if err != nil { return nil, fmt.Errorf("failed to get result for handle: %v", err) } results = append(results, result) } return results, nil - }) -) + } + RegisterWorkflow(dbosCtx, recoveryWorkflowFunc) -func TestQueueRecovery(t *testing.T) { - setupDBOS(t) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } queuedSteps := 5 @@ -290,7 +312,7 @@ func TestQueueRecovery(t *testing.T) { wfid := uuid.NewString() // Start the workflow. Wait for all steps to start. Verify that they started. - handle, err := recoveryWorkflow(context.Background(), "", WithWorkflowID(wfid)) + handle, err := RunAsWorkflow(dbosCtx, recoveryWorkflowFunc, "", WithWorkflowID(wfid)) if err != nil { t.Fatalf("failed to start workflow: %v", err) } @@ -305,7 +327,7 @@ func TestQueueRecovery(t *testing.T) { } // Recover the workflow, then resume it. - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveryHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -322,7 +344,7 @@ func TestQueueRecovery(t *testing.T) { for _, h := range recoveryHandles { if h.GetWorkflowID() == wfid { // Root workflow case - result, err := h.GetResult(context.Background()) + result, err := h.GetResult() if err != nil { t.Fatalf("failed to get result from recovered root workflow handle: %v", err) } @@ -337,7 +359,7 @@ func TestQueueRecovery(t *testing.T) { } } - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from original handle: %v", err) } @@ -351,11 +373,11 @@ func TestQueueRecovery(t *testing.T) { } // Rerun the workflow. Because each step is complete, none should start again. - rerunHandle, err := recoveryWorkflow(context.Background(), "test-input", WithWorkflowID(wfid)) + rerunHandle, err := RunAsWorkflow(dbosCtx, recoveryWorkflowFunc, "test-input", WithWorkflowID(wfid)) if err != nil { t.Fatalf("failed to rerun workflow: %v", err) } - rerunResult, err := rerunHandle.GetResult(context.Background()) + rerunResult, err := rerunHandle.GetResult() if err != nil { t.Fatalf("failed to get result from rerun handle: %v", err) } @@ -367,17 +389,23 @@ func TestQueueRecovery(t *testing.T) { t.Fatalf("expected recoveryStepCounter to remain %d, got %d", queuedSteps*2, recoveryStepCounter) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } } var ( - globalConcurrencyQueue = NewWorkflowQueue("test-global-concurrency-queue", WithGlobalConcurrency(1)) - workflowEvent1 = NewEvent() - workflowEvent2 = NewEvent() - workflowDoneEvent = NewEvent() - globalConcurrencyWorkflow = WithWorkflow(func(ctx context.Context, input string) (string, error) { + globalConcurrencyQueue = NewWorkflowQueue("test-global-concurrency-queue", WithGlobalConcurrency(1)) + workflowEvent1 = NewEvent() + workflowEvent2 = NewEvent() + workflowDoneEvent = NewEvent() +) + +func TestGlobalConcurrency(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Create workflow with dbosContext + globalConcurrencyWorkflowFunc := func(ctx DBOSContext, input string) (string, error) { switch input { case "workflow1": workflowEvent1.Set() @@ -386,19 +414,21 @@ var ( workflowEvent2.Set() } return input, nil - }) -) + } + RegisterWorkflow(dbosCtx, globalConcurrencyWorkflowFunc) -func TestGlobalConcurrency(t *testing.T) { - setupDBOS(t) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } // Enqueue two workflows - handle1, err := globalConcurrencyWorkflow(context.Background(), "workflow1", WithQueue(globalConcurrencyQueue.name)) + handle1, err := RunAsWorkflow(dbosCtx, globalConcurrencyWorkflowFunc, "workflow1", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow1: %v", err) } - handle2, err := globalConcurrencyWorkflow(context.Background(), "workflow2", WithQueue(globalConcurrencyQueue.name)) + handle2, err := RunAsWorkflow(dbosCtx, globalConcurrencyWorkflowFunc, "workflow2", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow2: %v", err) } @@ -422,7 +452,7 @@ func TestGlobalConcurrency(t *testing.T) { // Allow the first workflow to complete workflowDoneEvent.Set() - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from workflow1: %v", err) } @@ -433,14 +463,14 @@ func TestGlobalConcurrency(t *testing.T) { // Wait for the second workflow to start workflowEvent2.Wait() - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from workflow2: %v", err) } if result2 != "workflow2" { t.Fatalf("expected result from workflow2 to be 'workflow2', got %v", result2) } - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } } @@ -459,31 +489,39 @@ var ( NewEvent(), NewEvent(), } - blockingWf = WithWorkflow(func(ctx context.Context, i int) (int, error) { +) + +func TestWorkerConcurrency(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Create workflow with dbosContext + blockingWfFunc := func(ctx DBOSContext, i int) (int, error) { // Simulate a blocking operation startEvents[i].Set() completeEvents[i].Wait() return i, nil - }) -) + } + RegisterWorkflow(dbosCtx, blockingWfFunc) -func TestWorkerConcurrency(t *testing.T) { - setupDBOS(t) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } // First enqueue four blocking workflows - handle1, err := blockingWf(context.Background(), 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) + handle1, err := RunAsWorkflow(dbosCtx, blockingWfFunc, 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 1: %v", err) } - handle2, err := blockingWf(context.Background(), 1, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-2")) + handle2, err := RunAsWorkflow(dbosCtx, blockingWfFunc, 1, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-2")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 2: %v", err) } - _, err = blockingWf(context.Background(), 2, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-3")) + _, err = RunAsWorkflow(dbosCtx, blockingWfFunc, 2, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-3")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 3: %v", err) } - _, err = blockingWf(context.Background(), 3, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-4")) + _, err = RunAsWorkflow(dbosCtx, blockingWfFunc, 3, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-4")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 4: %v", err) } @@ -494,7 +532,7 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[1].IsSet || startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 1 to start, but others have started") } - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -505,12 +543,12 @@ func TestWorkerConcurrency(t *testing.T) { t.Fatalf("expected 3 workflows to be enqueued, got %d", len(workflows)) } - // Stop the queue runner before changing executor ID to avoid race conditions - stopQueueRunner() - // Change the executor ID to a different value - dbos.executorID = "worker-2" + // Stop the queue runner before changing dbosContext ID to avoid race conditions + stopQueueRunner(dbosCtx) + // Change the dbosContext ID to a different value + dbosCtx.(*dbosContext).executorID = "worker-2" // Restart the queue runner - restartQueueRunner() + restartQueueRunner(dbosCtx) // Wait for the second workflow to start on the second worker startEvents[1].Wait() @@ -518,7 +556,7 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 2 to start, but others have started") } - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err = dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -531,26 +569,26 @@ func TestWorkerConcurrency(t *testing.T) { // Unlock workflow 1, check wf 3 starts, check 4 stays blocked completeEvents[0].Set() - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from blocking workflow 1: %v", err) } if result1 != 0 { t.Fatalf("expected result from blocking workflow 1 to be 0, got %v", result1) } - // Stop the queue runner before changing executor ID to avoid race conditions - stopQueueRunner() - // Change the executor again and wait for the third workflow to start - dbos.executorID = "local" + // Stop the queue runner before changing dbosContext ID to avoid race conditions + stopQueueRunner(dbosCtx) + // Change the dbosContext again and wait for the third workflow to start + dbosCtx.(*dbosContext).executorID = "local" // Restart the queue runner - restartQueueRunner() + restartQueueRunner(dbosCtx) startEvents[2].Wait() // Ensure the fourth workflow is not started yet if startEvents[3].IsSet { t.Fatal("expected only blocking workflow 3 to start, but workflow 4 has started") } // Check that only one workflow is pending - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err = dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -563,22 +601,22 @@ func TestWorkerConcurrency(t *testing.T) { // Unlock workflow 2 and check wf 4 starts completeEvents[1].Set() - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from blocking workflow 2: %v", err) } if result2 != 1 { t.Fatalf("expected result from blocking workflow 2 to be 1, got %v", result2) } - // Stop the queue runner before changing executor ID to avoid race conditions - stopQueueRunner() - // change executor again and wait for the fourth workflow to start - dbos.executorID = "worker-2" + // Stop the queue runner before changing dbosContext ID to avoid race conditions + stopQueueRunner(dbosCtx) + // change dbosContext again and wait for the fourth workflow to start + dbosCtx.(*dbosContext).executorID = "worker-2" // Restart the queue runner - restartQueueRunner() + restartQueueRunner(dbosCtx) startEvents[3].Wait() // Check no workflow is enqueued - workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err = dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusEnqueued}, queueName: workerConcurrencyQueue.name, }) @@ -593,11 +631,11 @@ func TestWorkerConcurrency(t *testing.T) { completeEvents[2].Set() completeEvents[3].Set() - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } - dbos.executorID = "local" // Reset executor ID for future tests + dbosCtx.(*dbosContext).executorID = "local" // Reset executor ID for future tests } var ( @@ -606,27 +644,36 @@ var ( workerConcurrencyRecoveryStartEvent2 = NewEvent() workerConcurrencyRecoveryCompleteEvent1 = NewEvent() workerConcurrencyRecoveryCompleteEvent2 = NewEvent() - workerConcurrencyRecoveryBlockingWf1 = WithWorkflow(func(ctx context.Context, input string) (string, error) { +) + +func TestWorkerConcurrencyXRecovery(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Create workflows with dbosContext + workerConcurrencyRecoveryBlockingWf1 := func(ctx DBOSContext, input string) (string, error) { workerConcurrencyRecoveryStartEvent1.Set() workerConcurrencyRecoveryCompleteEvent1.Wait() return input, nil - }) - workerConcurrencyRecoveryBlockingWf2 = WithWorkflow(func(ctx context.Context, input string) (string, error) { + } + RegisterWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf1) + workerConcurrencyRecoveryBlockingWf2 := func(ctx DBOSContext, input string) (string, error) { workerConcurrencyRecoveryStartEvent2.Set() workerConcurrencyRecoveryCompleteEvent2.Wait() return input, nil - }) -) + } + RegisterWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf2) -func TestWorkerConcurrencyXRecovery(t *testing.T) { - setupDBOS(t) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } // Enqueue two workflows on a queue with worker concurrency = 1 - handle1, err := workerConcurrencyRecoveryBlockingWf1(context.Background(), "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) + handle1, err := RunAsWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf1, "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 1: %v", err) } - handle2, err := workerConcurrencyRecoveryBlockingWf2(context.Background(), "workflow2", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-2")) + handle2, err := RunAsWorkflow(dbosCtx, workerConcurrencyRecoveryBlockingWf2, "workflow2", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-2")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 2: %v", err) } @@ -652,7 +699,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } // Now, manually call the recoverPendingWorkflows method - recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveryHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -691,7 +738,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { workerConcurrencyRecoveryCompleteEvent2.Set() // Get result from first workflow - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from workflow1: %v", err) } @@ -700,7 +747,7 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } // Get result from second workflow - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from workflow2: %v", err) } @@ -709,22 +756,29 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { } // Ensure queueEntriesAreCleanedUp is set to true - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after worker concurrency recovery test") } } var ( - rateLimiterQueue = NewWorkflowQueue("test-rate-limiter-queue", WithRateLimiter(&RateLimiter{Limit: 5, Period: 1.8})) - rateLimiterWorkflow = WithWorkflow(rateLimiterTestWorkflow) + rateLimiterQueue = NewWorkflowQueue("test-rate-limiter-queue", WithRateLimiter(&RateLimiter{Limit: 5, Period: 1.8})) ) -func rateLimiterTestWorkflow(ctx context.Context, _ string) (time.Time, error) { +func rateLimiterTestWorkflow(ctx DBOSContext, _ string) (time.Time, error) { return time.Now(), nil // Return current time } func TestQueueRateLimiter(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + + // Create workflow with dbosContext + RegisterWorkflow(dbosCtx, rateLimiterTestWorkflow) + + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } limit := 5 period := 1.8 @@ -738,7 +792,7 @@ func TestQueueRateLimiter(t *testing.T) { // executed simultaneously, followed by a wait of the period, // followed by the next wave. for i := 0; i < limit*numWaves; i++ { - handle, err := rateLimiterWorkflow(context.Background(), "", WithQueue(rateLimiterQueue.name)) + handle, err := RunAsWorkflow(dbosCtx, rateLimiterTestWorkflow, "", WithQueue(rateLimiterQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow %d: %v", i, err) } @@ -747,7 +801,7 @@ func TestQueueRateLimiter(t *testing.T) { // Get results from all workflows for _, handle := range handles { - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from workflow: %v", err) } @@ -810,7 +864,7 @@ func TestQueueRateLimiter(t *testing.T) { } // Verify all queue entries eventually get cleaned up. - if !queueEntriesAreCleanedUp() { + if !queueEntriesAreCleanedUp(dbosCtx) { t.Fatal("expected queue entries to be cleaned up after rate limiter test") } } diff --git a/dbos/recovery.go b/dbos/recovery.go index 3dc4aeb8..21d081cb 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -1,17 +1,16 @@ package dbos import ( - "context" "strings" ) -func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]WorkflowHandle[any], error) { +func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors - pendingWorkflows, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + pendingWorkflows, err := ctx.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ status: []WorkflowStatusType{WorkflowStatusPending}, executorIDs: executorIDs, - applicationVersion: dbos.applicationVersion, + applicationVersion: ctx.applicationVersion, }) if err != nil { return nil, err @@ -25,37 +24,29 @@ func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]Workf } } - // fmt.Println("Recovering workflow:", workflow.ID, "Name:", workflow.Name, "Input:", workflow.Input, "QueueName:", workflow.QueueName) if workflow.QueueName != "" { - cleared, err := dbos.systemDB.ClearQueueAssignment(ctx, workflow.ID) + cleared, err := ctx.systemDB.ClearQueueAssignment(ctx.ctx, workflow.ID) if err != nil { getLogger().Error("Error clearing queue assignment for workflow", "workflow_id", workflow.ID, "name", workflow.Name, "error", err) continue } if cleared { - workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID}) + workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, dbosContext: ctx}) } continue } - registeredWorkflow, exists := registry[workflow.Name] + registeredWorkflow, exists := ctx.workflowRegistry[workflow.Name] if !exists { getLogger().Error("Workflow function not found in registry", "workflow_id", workflow.ID, "name", workflow.Name) continue } // Convert workflow parameters to options - opts := []workflowOption{ + opts := []WorkflowOption{ WithWorkflowID(workflow.ID), } - // XXX we'll figure out the exact timeout/deadline settings later - if workflow.Timeout != 0 { - opts = append(opts, WithTimeout(workflow.Timeout)) - } - if !workflow.Deadline.IsZero() { - opts = append(opts, WithDeadline(workflow.Deadline)) - } - + // Create a workflow context from the executor context handle, err := registeredWorkflow.wrappedFunction(ctx, workflow.Input, opts...) if err != nil { return nil, err diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 7e6acc39..d10e0f87 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -17,17 +17,12 @@ import ( [x] Set/get event with user defined types */ -var ( - builtinWf = WithWorkflow(encodingWorkflowBuiltinTypes) - structWf = WithWorkflow(encodingWorkflowStruct) -) - // Builtin types func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { return input, errors.New("step error") } -func encodingWorkflowBuiltinTypes(ctx context.Context, input string) (string, error) { +func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) { stepResult, err := RunAsStep(ctx, encodingStepBuiltinTypes, 123) return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err) } @@ -53,7 +48,7 @@ type SimpleStruct struct { B int } -func encodingWorkflowStruct(ctx context.Context, input WorkflowInputStruct) (StepOutputStruct, error) { +func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) { return RunAsStep(ctx, encodingStepStruct, StepInputStruct{ A: input.A, B: fmt.Sprintf("%d", input.B), @@ -68,17 +63,21 @@ func encodingStepStruct(ctx context.Context, input StepInputStruct) (StepOutputS } func TestWorkflowEncoding(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Register workflows with executor + RegisterWorkflow(executor, encodingWorkflowBuiltinTypes) + RegisterWorkflow(executor, encodingWorkflowStruct) t.Run("BuiltinTypes", func(t *testing.T) { // Test a workflow that uses a built-in type (string) - directHandle, err := builtinWf(context.Background(), "test") + directHandle, err := RunAsWorkflow(executor, encodingWorkflowBuiltinTypes, "test") if err != nil { t.Fatalf("failed to execute workflow: %v", err) } // Test result and error from direct handle - directHandleResult, err := directHandle.GetResult(context.Background()) + directHandleResult, err := directHandle.GetResult() if directHandleResult != "123" { t.Fatalf("expected direct handle result to be '123', got %v", directHandleResult) } @@ -87,11 +86,11 @@ func TestWorkflowEncoding(t *testing.T) { } // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[string](directHandle.GetWorkflowID()) + retrieveHandler, err := RetrieveWorkflow[string](executor.(*dbosContext), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to retrieve workflow: %v", err) } - retrievedResult, err := retrieveHandler.GetResult(context.Background()) + retrievedResult, err := retrieveHandler.GetResult() if retrievedResult != "123" { t.Fatalf("expected retrieved result to be '123', got %v", retrievedResult) } @@ -100,7 +99,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := executor.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -138,7 +137,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from GetWorkflowSteps - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) + steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -171,13 +170,13 @@ func TestWorkflowEncoding(t *testing.T) { B: 456, } - directHandle, err := structWf(context.Background(), input) + directHandle, err := RunAsWorkflow(executor, encodingWorkflowStruct, input) if err != nil { t.Fatalf("failed to execute step workflow: %v", err) } // Test result from direct handle - directResult, err := directHandle.GetResult(context.Background()) + directResult, err := directHandle.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -195,11 +194,11 @@ func TestWorkflowEncoding(t *testing.T) { } // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](directHandle.GetWorkflowID()) + retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](executor.(*dbosContext), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to retrieve step workflow: %v", err) } - retrievedResult, err := retrieveHandler.GetResult(context.Background()) + retrievedResult, err := retrieveHandler.GetResult() if err != nil { t.Fatalf("expected no error but got: %v", err) } @@ -217,7 +216,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := executor.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { @@ -259,7 +258,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from GetWorkflowSteps - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) + steps, err := executor.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -301,7 +300,7 @@ type UserDefinedEventData struct { } `json:"details"` } -func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, error) { +func setEventUserDefinedTypeWorkflow(ctx DBOSContext, input string) (string, error) { eventData := UserDefinedEventData{ ID: 42, Name: "test-event", @@ -314,27 +313,28 @@ func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, }, } - err := SetEvent(ctx, WorkflowSetEventInput[UserDefinedEventData]{Key: input, Message: eventData}) + err := SetEvent(ctx, WorkflowSetEventInputGeneric[UserDefinedEventData]{Key: input, Message: eventData}) if err != nil { return "", err } return "user-defined-event-set", nil } -var setEventUserDefinedTypeWf = WithWorkflow(setEventUserDefinedTypeWorkflow) - func TestSetEventSerialize(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Register workflow with executor + RegisterWorkflow(executor, setEventUserDefinedTypeWorkflow) t.Run("SetEventUserDefinedType", func(t *testing.T) { // Start a workflow that sets an event with a user-defined type - setHandle, err := setEventUserDefinedTypeWf(context.Background(), "user-defined-key") + setHandle, err := RunAsWorkflow(executor, setEventUserDefinedTypeWorkflow, "user-defined-key") if err != nil { t.Fatalf("failed to start workflow with user-defined event type: %v", err) } // Wait for the workflow to complete - result, err := setHandle.GetResult(context.Background()) + result, err := setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from user-defined event workflow: %v", err) } @@ -343,7 +343,7 @@ func TestSetEventSerialize(t *testing.T) { } // Retrieve the event to verify it was properly serialized and can be deserialized - retrievedEvent, err := GetEvent[UserDefinedEventData](context.Background(), WorkflowGetEventInput{ + retrievedEvent, err := GetEvent[UserDefinedEventData](executor, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "user-defined-key", Timeout: 3 * time.Second, @@ -374,8 +374,7 @@ func TestSetEventSerialize(t *testing.T) { }) } - -func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (string, error) { +func sendUserDefinedTypeWorkflow(ctx DBOSContext, destinationID string) (string, error) { // Create an instance of our user-defined type inside the workflow sendData := UserDefinedEventData{ ID: 42, @@ -402,7 +401,7 @@ func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (str return "user-defined-message-sent", nil } -func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefinedEventData, error) { +func recvUserDefinedTypeWorkflow(ctx DBOSContext, input string) (UserDefinedEventData, error) { // Receive the user-defined type message result, err := Recv[UserDefinedEventData](ctx, WorkflowRecvInput{ Topic: "user-defined-topic", @@ -411,27 +410,28 @@ func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefined return result, err } -var sendUserDefinedTypeWf = WithWorkflow(sendUserDefinedTypeWorkflow) -var recvUserDefinedTypeWf = WithWorkflow(recvUserDefinedTypeWorkflow) - func TestSendSerialize(t *testing.T) { - setupDBOS(t) + executor := setupDBOS(t) + + // Register workflows with executor + RegisterWorkflow(executor, sendUserDefinedTypeWorkflow) + RegisterWorkflow(executor, recvUserDefinedTypeWorkflow) t.Run("SendUserDefinedType", func(t *testing.T) { // Start a receiver workflow first - recvHandle, err := recvUserDefinedTypeWf(context.Background(), "recv-input") + recvHandle, err := RunAsWorkflow(executor, recvUserDefinedTypeWorkflow, "recv-input") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Start a sender workflow that sends a message with a user-defined type - sendHandle, err := sendUserDefinedTypeWf(context.Background(), recvHandle.GetWorkflowID()) + sendHandle, err := RunAsWorkflow(executor, sendUserDefinedTypeWorkflow, recvHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to start workflow with user-defined send type: %v", err) } // Wait for the sender workflow to complete - sendResult, err := sendHandle.GetResult(context.Background()) + sendResult, err := sendHandle.GetResult() if err != nil { t.Fatalf("failed to get result from user-defined send workflow: %v", err) } @@ -440,7 +440,7 @@ func TestSendSerialize(t *testing.T) { } // Wait for the receiver workflow to complete and get the message - receivedData, err := recvHandle.GetResult(context.Background()) + receivedData, err := recvHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } diff --git a/dbos/system_database.go b/dbos/system_database.go index 1ccba59f..39db27e3 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -42,19 +42,19 @@ type SystemDatabase interface { // Steps RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) - GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) + GetWorkflowSteps(ctx context.Context, workflowID string) ([]stepInfo, error) // Communication (special steps) - Send(ctx context.Context, input workflowSendInputInternal) error + Send(ctx context.Context, input WorkflowSendInputInternal) error Recv(ctx context.Context, input WorkflowRecvInput) (any, error) - SetEvent(ctx context.Context, input workflowSetEventInputInternal) error + SetEvent(ctx context.Context, input WorkflowSetEventInput) error GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) // Timers (special steps) Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) // Queues - DequeueWorkflows(ctx context.Context, queue WorkflowQueue) ([]dequeuedWorkflow, error) + DequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) ClearQueueAssignment(ctx context.Context, workflowID string) (bool, error) } @@ -644,6 +644,12 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str query := `SELECT status, output, error FROM dbos.workflow_status WHERE workflow_uuid = $1` var status WorkflowStatusType for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + row := s.pool.QueryRow(ctx, query, workflowID) var outputString *string var errorStr *string @@ -670,7 +676,7 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str case WorkflowStatusCancelled: return nil, newAwaitedWorkflowCancelledError(workflowID) default: - time.Sleep(1 * time.Second) // Wait before checking again + time.Sleep(1 * time.Second) } } } @@ -936,7 +942,7 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input chec return result, nil } -type StepInfo struct { +type stepInfo struct { FunctionID int FunctionName string Output any @@ -944,7 +950,7 @@ type StepInfo struct { ChildWorkflowID string } -func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) { +func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string) ([]stepInfo, error) { query := `SELECT function_id, function_name, output, error, child_workflow_id FROM dbos.operation_outputs WHERE workflow_uuid = $1` @@ -955,9 +961,9 @@ func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string } defer rows.Close() - var steps []StepInfo + var steps []stepInfo for rows.Next() { - var step StepInfo + var step stepInfo var outputString *string var errorString *string var childWorkflowID *string @@ -1115,16 +1121,16 @@ func (s *systemDatabase) notificationListenerLoop(ctx context.Context) { const _DBOS_NULL_TOPIC = "__null__topic__" -type workflowSendInputInternal struct { - destinationID string - message any - topic string +type WorkflowSendInputInternal struct { + DestinationID string + Message any + Topic string } // Send is a special type of step that sends a message to another workflow. // Can be called both within a workflow (as a step) or outside a workflow (directly). // When called within a workflow: durability and the function run in the same transaction, and we forbid nested step execution -func (s *systemDatabase) Send(ctx context.Context, input workflowSendInputInternal) error { +func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInputInternal) error { functionName := "DBOS.send" // Get workflow state from context (optional for Send as we can send from outside a workflow) @@ -1166,22 +1172,22 @@ func (s *systemDatabase) Send(ctx context.Context, input workflowSendInputIntern // Set default topic if not provided topic := _DBOS_NULL_TOPIC - if len(input.topic) > 0 { - topic = input.topic + if len(input.Topic) > 0 { + topic = input.Topic } // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.message) + messageString, err := serialize(input.Message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) } insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)` - _, err = tx.Exec(ctx, insertQuery, input.destinationID, topic, messageString) + _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString) if err != nil { // Check for foreign key violation (destination workflow doesn't exist) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23503" { - return newNonExistentWorkflowError(input.destinationID) + return newNonExistentWorkflowError(input.DestinationID) } return fmt.Errorf("failed to insert notification: %w", err) } @@ -1357,12 +1363,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any return message, nil } -type workflowSetEventInputInternal struct { - key string - message any +type WorkflowSetEventInput struct { + Key string + Message any } -func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInputInternal) error { +func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInput) error { functionName := "DBOS.setEvent" // Get workflow state from context @@ -1400,7 +1406,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInp } // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.message) + messageString, err := serialize(input.Message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) } @@ -1411,7 +1417,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInp ON CONFLICT (workflow_uuid, key) DO UPDATE SET value = EXCLUDED.value` - _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.key, messageString) + _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, messageString) if err != nil { return fmt.Errorf("failed to insert/update workflow event: %w", err) } @@ -1567,7 +1573,8 @@ type dequeuedWorkflow struct { input string } -func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQueue) ([]dequeuedWorkflow, error) { +// TODO input struct +func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) { // Begin transaction with snapshot isolation tx, err := s.pool.Begin(ctx) if err != nil { @@ -1638,7 +1645,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue pendingWorkflowsDict[executorIDRow] = taskCount } - localPendingWorkflows := pendingWorkflowsDict[dbos.executorID] + localPendingWorkflows := pendingWorkflowsDict[executorID] // Check worker concurrency limit if queue.workerConcurrency != nil { @@ -1705,7 +1712,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue } // Execute the query to get workflow IDs - rows, err := tx.Query(ctx, query, queue.name, WorkflowStatusEnqueued, dbos.applicationVersion) + rows, err := tx.Query(ctx, query, queue.name, WorkflowStatusEnqueued, applicationVersion) if err != nil { return nil, fmt.Errorf("failed to query enqueued workflows: %w", err) } @@ -1755,8 +1762,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue var inputString *string err := tx.QueryRow(ctx, updateQuery, WorkflowStatusPending, - dbos.applicationVersion, - dbos.executorID, + applicationVersion, + executorID, startTimeMs, id).Scan(&retWorkflow.name, &inputString) diff --git a/dbos/utils_test.go b/dbos/utils_test.go index d2f04731..fc839d2a 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -12,7 +12,7 @@ import ( "github.com/jackc/pgx/v5" ) -func getDatabaseURL(t *testing.T) string { +func getDatabaseURL() string { databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL") if databaseURL == "" { password := os.Getenv("PGPASSWORD") @@ -25,10 +25,10 @@ func getDatabaseURL(t *testing.T) string { } /* Test database setup */ -func setupDBOS(t *testing.T) { +func setupDBOS(t *testing.T) DBOSContext { t.Helper() - databaseURL := getDatabaseURL(t) + databaseURL := getDatabaseURL() // Clean up the test database parsedURL, err := pgx.ParseConfig(databaseURL) @@ -54,28 +54,26 @@ func setupDBOS(t *testing.T) { t.Fatalf("failed to drop test database: %v", err) } - err = Initialize(Config{ + dbosContext, err := NewDBOSContext(Config{ DatabaseURL: databaseURL, AppName: "test-app", }) if err != nil { t.Fatalf("failed to create DBOS instance: %v", err) } - - err = Launch() - if err != nil { - t.Fatalf("failed to launch DBOS instance: %v", err) - } - - if dbos == nil { + if dbosContext == nil { t.Fatal("expected DBOS instance but got nil") } // Register cleanup to run after test completes t.Cleanup(func() { fmt.Println("Cleaning up DBOS instance...") - Shutdown() + if dbosContext != nil { + dbosContext.Shutdown() + } }) + + return dbosContext } /* Event struct provides a simple synchronization primitive that can be used to signal between goroutines. */ @@ -115,27 +113,32 @@ func (e *Event) Clear() { /* Helpers */ // stopQueueRunner stops the queue runner for testing purposes -func stopQueueRunner() { - if dbos != nil && dbos.queueRunnerCancelFunc != nil { - dbos.queueRunnerCancelFunc() - // Wait for queue runner to finish - <-dbos.queueRunnerDone +func stopQueueRunner(ctx DBOSContext) { + if ctx != nil { + exec := ctx.(*dbosContext) + if exec.queueRunnerCancelFunc != nil { + exec.queueRunnerCancelFunc() + // Wait for queue runner to finish + <-exec.queueRunnerDone + } } } // restartQueueRunner restarts the queue runner for testing purposes -func restartQueueRunner() { - if dbos != nil { +func restartQueueRunner(ctx DBOSContext) { + if ctx != nil { + exec := ctx.(*dbosContext) // Create new context and cancel function + // FIXME: cancellation now has to go through the DBOSContext ctx, cancel := context.WithCancel(context.Background()) - dbos.queueRunnerCtx = ctx - dbos.queueRunnerCancelFunc = cancel - dbos.queueRunnerDone = make(chan struct{}) + exec.queueRunnerCtx = ctx + exec.queueRunnerCancelFunc = cancel + exec.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(dbos.queueRunnerDone) - queueRunner(ctx) + defer close(exec.queueRunnerDone) + queueRunner(exec) }() } } @@ -152,12 +155,13 @@ func equal(a, b []int) bool { return true } -func queueEntriesAreCleanedUp() bool { +func queueEntriesAreCleanedUp(ctx DBOSContext) bool { maxTries := 10 success := false for range maxTries { // Begin transaction - tx, err := dbos.systemDB.(*systemDatabase).pool.Begin(context.Background()) + exec := ctx.(*dbosContext) + tx, err := exec.systemDB.(*systemDatabase).pool.Begin(context.Background()) if err != nil { return false } diff --git a/dbos/workflow.go b/dbos/workflow.go index 467d6ab8..e321f506 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -8,7 +8,6 @@ import ( "math" "reflect" "runtime" - "sync" "time" "github.com/google/uuid" @@ -55,17 +54,16 @@ type WorkflowStatus struct { } // workflowState holds the runtime state for a workflow execution -// TODO: this should be an internal type. Workflows should have aptly named getters to access the state type workflowState struct { workflowID string - stepCounter int + stepID int isWithinStep bool } // NextStepID returns the next step ID and increments the counter func (ws *workflowState) NextStepID() int { - ws.stepCounter++ - return ws.stepCounter + ws.stepID++ + return ws.stepID } /********************************/ @@ -79,26 +77,27 @@ type workflowOutcome[R any] struct { } type WorkflowHandle[R any] interface { - GetResult(ctx context.Context) (R, error) + GetResult() (R, error) GetStatus() (WorkflowStatus, error) - GetWorkflowID() string // XXX we could have a base struct with GetWorkflowID and then embed it in the implementations + GetWorkflowID() string } // workflowHandle is a concrete implementation of WorkflowHandle type workflowHandle[R any] struct { workflowID string outcomeChan chan workflowOutcome[R] + dbosContext DBOSContext } // GetResult waits for the workflow to complete and returns the result -func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { +func (h *workflowHandle[R]) GetResult() (R, error) { outcome, ok := <-h.outcomeChan // Blocking read if !ok { // Return an error if the channel was closed. In normal operations this would happen if GetResul() is called twice on a handler. The first call should get the buffered result, the second call find zero values (channel is empty and closed). return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?") } // If we are calling GetResult inside a workflow, record the result as a step result - parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) + parentWorkflowState, ok := h.dbosContext.(*dbosContext).ctx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil if isChildWorkflow { encodedOutput, encErr := serialize(outcome.result) @@ -112,7 +111,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { output: encodedOutput, err: outcome.err, } - recordResultErr := dbos.systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := h.dbosContext.(*dbosContext).systemDB.RecordChildGetResult(h.dbosContext.(*dbosContext).ctx, recordGetResultInput) if recordResultErr != nil { getLogger().Error("failed to record get result", "error", recordResultErr) return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) @@ -123,8 +122,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowHandle[R]) GetStatus() (WorkflowStatus, error) { - ctx := context.Background() - workflowStatuses, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.ListWorkflows(h.dbosContext.(*dbosContext).ctx, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -141,11 +139,14 @@ func (h *workflowHandle[R]) GetWorkflowID() string { } type workflowPollingHandle[R any] struct { - workflowID string + workflowID string + dbosContext DBOSContext } -func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { - result, err := dbos.systemDB.AwaitWorkflowResult(ctx, h.workflowID) +func (h *workflowPollingHandle[R]) GetResult() (R, error) { + // FIXME this should use a context available to the user, so they can cancel it instead of infinite waiting + ctx := context.Background() + result, err := h.dbosContext.(*dbosContext).systemDB.AwaitWorkflowResult(h.dbosContext.(*dbosContext).ctx, h.workflowID) if result != nil { typedResult, ok := result.(R) if !ok { @@ -167,7 +168,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { output: encodedOutput, err: err, } - recordResultErr := dbos.systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := h.dbosContext.(*dbosContext).systemDB.RecordChildGetResult(h.dbosContext.(*dbosContext).ctx, recordGetResultInput) if recordResultErr != nil { // XXX do we want to fail this? getLogger().Error("failed to record get result", "error", recordResultErr) @@ -180,8 +181,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowPollingHandle[R]) GetStatus() (WorkflowStatus, error) { - ctx := context.Background() - workflowStatuses, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.ListWorkflows(h.dbosContext.(*dbosContext).ctx, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, }) if err != nil { @@ -200,36 +200,74 @@ func (h *workflowPollingHandle[R]) GetWorkflowID() string { /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ -type typedErasedWorkflowWrapperFunc func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) +type GenericWrappedWorkflowFunc[P any, R any] func(ctx DBOSContext, input P, opts ...WorkflowOption) (WorkflowHandle[R], error) +type WrappedWorkflowFunc func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) type workflowRegistryEntry struct { - wrappedFunction typedErasedWorkflowWrapperFunc + wrappedFunction WrappedWorkflowFunc maxRetries int } -var registry = make(map[string]workflowRegistryEntry) -var regMutex sync.RWMutex - // Register adds a workflow function to the registry (thread-safe, only once per name) -func registerWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { - regMutex.Lock() - defer regMutex.Unlock() +func registerWorkflow(ctx DBOSContext, workflowName string, fn WrappedWorkflowFunc, maxRetries int) { + // Skip if we don't have a concrete dbosContext + c, ok := ctx.(*dbosContext) + if !ok { + return + } + + c.workflowRegMutex.Lock() + defer c.workflowRegMutex.Unlock() - if _, exists := registry[fqn]; exists { - getLogger().Error("workflow function already registered", "fqn", fqn) - panic(newConflictingRegistrationError(fqn)) + if _, exists := c.workflowRegistry[workflowName]; exists { + getLogger().Error("workflow function already registered", "fqn", workflowName) + panic(newConflictingRegistrationError(workflowName)) } - registry[fqn] = workflowRegistryEntry{ + c.workflowRegistry[workflowName] = workflowRegistryEntry{ wrappedFunction: fn, maxRetries: maxRetries, } } +func registerScheduledWorkflow(ctx DBOSContext, workflowName string, fn WorkflowFunc, cronSchedule string) { + // Skip if we don't have a concrete dbosContext + c, ok := ctx.(*dbosContext) + if !ok { + return + } + + c.getWorkflowScheduler().Start() + var entryID cron.EntryID + entryID, err := c.getWorkflowScheduler().AddFunc(cronSchedule, func() { + // Execute the workflow on the cron schedule once DBOS is launched + if !c.launched { + return + } + // Get the scheduled time from the cron entry + entry := c.getWorkflowScheduler().Entry(entryID) + scheduledTime := entry.Prev + if scheduledTime.IsZero() { + // Use Next if Prev is not set, which will only happen for the first run + scheduledTime = entry.Next + } + wfID := fmt.Sprintf("sched-%s-%s", workflowName, scheduledTime) // XXX we can rethink the format + opts := []WorkflowOption{ + WithWorkflowID(wfID), + WithQueue(_DBOS_INTERNAL_QUEUE_NAME), + withWorkflowName(workflowName), + } + ctx.RunAsWorkflow(ctx, fn, scheduledTime, opts...) + }) + if err != nil { + panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) + } + getLogger().Info("Registered scheduled workflow", "fqn", workflowName, "cron_schedule", cronSchedule) +} + type workflowRegistrationParams struct { cronSchedule string maxRetries int - // Likely we will allow a name here } type workflowRegistrationOption func(*workflowRegistrationParams) @@ -250,22 +288,28 @@ func WithSchedule(schedule string) workflowRegistrationOption { } } -func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { - if dbos != nil { - getLogger().Warn("WithWorkflow called after DBOS initialization, dynamic registration is not supported") - return nil +// RegisterWorkflow registers the provided function as a durable workflow with the provided DBOSContext workflow registry +// If the workflow is a scheduled workflow (determined by the presence of a cron schedule), it will also register a cron job to execute it +// RegisterWorkflow is generically typed, providing compile-time type checking and allowing us to register the workflow input and output types for gob encoding +// The registered workflow is wrapped in a typed-erased wrapper which performs runtime type checks and conversions +// To execute the workflow, use DBOSContext.RunAsWorkflow +func RegisterWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], opts ...workflowRegistrationOption) { + if ctx == nil { + panic("ctx cannot be nil") + } + + if fn == nil { + panic("workflow function cannot be nil") } registrationParams := workflowRegistrationParams{ maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, } + for _, opt := range opts { opt(®istrationParams) } - if fn == nil { - panic("workflow function cannot be nil") - } fqn := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // Registry the input/output types for gob encoding @@ -274,72 +318,53 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrat gob.Register(p) gob.Register(r) - // Wrap the function in a durable workflow - wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx context.Context, workflowInput P, opts ...workflowOption) (WorkflowHandle[R], error) { - opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) - return runAsWorkflow(ctx, fn, workflowInput, opts...) - }) - - // If this is a scheduled workflow, register a cron job - if registrationParams.cronSchedule != "" { - if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { - panic(fmt.Sprintf("scheduled workflow function must accept ScheduledWorkflowInput as input, got %T", p)) - } - if workflowScheduler == nil { - workflowScheduler = cron.New(cron.WithSeconds()) - } - var entryID cron.EntryID - entryID, err := workflowScheduler.AddFunc(registrationParams.cronSchedule, func() { - // Execute the workflow on the cron schedule once DBOS is launched - if dbos == nil { - return - } - // Get the scheduled time from the cron entry - entry := workflowScheduler.Entry(entryID) - scheduledTime := entry.Prev - if scheduledTime.IsZero() { - // Use Next if Prev is not set, which will only happen for the first run - scheduledTime = entry.Next - } - wfID := fmt.Sprintf("sched-%s-%s", fqn, scheduledTime) // XXX we can rethink the format - wrappedFunction(context.Background(), any(scheduledTime).(P), WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME)) - }) - if err != nil { - panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) + // Register a type-erased version of the durable workflow for recovery + typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { + // This type check is redundant with the one in the wrapper, but I'd better be safe than sorry + typedInput, ok := input.(P) + if !ok { + return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } - getLogger().Info("Registered scheduled workflow", "fqn", fqn, "cron_schedule", registrationParams.cronSchedule) - } + return fn(ctx, typedInput) + }) - // Register a type-erased version of the durable workflow for recovery - typeErasedWrapper := func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) { + typeErasedWrapper := WrappedWorkflowFunc(func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } - handle, err := wrappedFunction(ctx, typedInput, opts...) + opts = append(opts, withWorkflowName(fqn)) // Append the name so ctx.RunAsWorkflow can look it up from the registry to apply registration-time options + handle, err := ctx.RunAsWorkflow(ctx, typedErasedWorkflow, typedInput, opts...) if err != nil { return nil, err } - return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID()}, nil - } - registerWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) + return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), dbosContext: ctx}, nil + }) + registerWorkflow(ctx, fqn, typeErasedWrapper, registrationParams.maxRetries) - return wrappedFunction + // If this is a scheduled workflow, register a cron job + if registrationParams.cronSchedule != "" { + if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { + panic(fmt.Sprintf("scheduled workflow function must accept a time.Time as input, got %T", p)) + } + registerScheduledWorkflow(ctx, fqn, typedErasedWorkflow, registrationParams.cronSchedule) + } } /**********************************/ /******* WORKFLOW FUNCTIONS *******/ /**********************************/ -type contextKey string +type DBOSContextKey string -const workflowStateKey contextKey = "workflowState" +const workflowStateKey DBOSContextKey = "workflowState" -type WorkflowFunc[P any, R any] func(ctx context.Context, input P) (R, error) -type WorkflowWrapperFunc[P any, R any] func(ctx context.Context, input P, opts ...workflowOption) (WorkflowHandle[R], error) +type GenericWorkflowFunc[P any, R any] func(ctx DBOSContext, input P) (R, error) +type WorkflowFunc func(ctx DBOSContext, input any) (any, error) type workflowParams struct { + workflowName string workflowID string timeout time.Duration deadline time.Time @@ -348,58 +373,128 @@ type workflowParams struct { maxRetries int } -type workflowOption func(*workflowParams) +type WorkflowOption func(*workflowParams) -func WithWorkflowID(id string) workflowOption { +func WithWorkflowID(id string) WorkflowOption { return func(p *workflowParams) { p.workflowID = id } } -func WithTimeout(timeout time.Duration) workflowOption { +func WithTimeout(timeout time.Duration) WorkflowOption { return func(p *workflowParams) { p.timeout = timeout } } -func WithDeadline(deadline time.Time) workflowOption { +func WithDeadline(deadline time.Time) WorkflowOption { return func(p *workflowParams) { p.deadline = deadline } } -func WithQueue(queueName string) workflowOption { +func WithQueue(queueName string) WorkflowOption { return func(p *workflowParams) { p.queueName = queueName } } -func WithApplicationVersion(version string) workflowOption { +func WithApplicationVersion(version string) WorkflowOption { return func(p *workflowParams) { p.applicationVersion = version } } -func WithWorkflowMaxRetries(maxRetries int) workflowOption { +func withWorkflowName(name string) WorkflowOption { return func(p *workflowParams) { - p.maxRetries = maxRetries + p.workflowName = name } } -func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { +func RunAsWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { + if ctx == nil { + return nil, fmt.Errorf("ctx cannot be nil") + } + + // Add the fn name to the options so we can communicate it with DBOSContext.RunAsWorkflow + opts = append(opts, withWorkflowName(runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name())) + + typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { + return fn(ctx, input.(P)) + }) + + handle, err := ctx.(*dbosContext).RunAsWorkflow(ctx, typedErasedWorkflow, input, opts...) + if err != nil { + return nil, err + } + + // If we got a polling handle, return its typed version + if pollingHandle, ok := handle.(*workflowPollingHandle[any]); ok { + // We need to convert the polling handle to a typed handle + typedPollingHandle := &workflowPollingHandle[R]{ + workflowID: pollingHandle.workflowID, + dbosContext: ctx, + } + return typedPollingHandle, nil + } + + // Create a typed channel for the user to get a typed handle + if handle, ok := handle.(*workflowHandle[any]); ok { + typedOutcomeChan := make(chan workflowOutcome[R], 1) + + go func() { + defer close(typedOutcomeChan) + outcome := <-handle.outcomeChan + + resultErr := outcome.err + var typedResult R + if typedRes, ok := outcome.result.(R); ok { + typedResult = typedRes + } else { // This should never happen + typedResult = *new(R) + typeErr := fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), outcome.result) + resultErr = errors.Join(resultErr, typeErr) + } + + typedOutcomeChan <- workflowOutcome[R]{ + result: typedResult, + err: resultErr, + } + }() + + typedHandle := &workflowHandle[R]{ + workflowID: handle.workflowID, + outcomeChan: typedOutcomeChan, + dbosContext: handle.dbosContext, + } + + return typedHandle, nil + } + + // Should never happen + return nil, fmt.Errorf("unexpected workflow handle type: %T", handle) +} + +func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { // Apply options to build params params := workflowParams{ - applicationVersion: dbos.applicationVersion, + applicationVersion: c.GetApplicationVersion(), } for _, opt := range opts { opt(¶ms) } - // First, create a context for the workflow - dbosWorkflowContext := context.Background() + // Lookup the registry for registration-time options + registeredWorkflow, exists := c.workflowRegistry[params.workflowName] + if !exists { + return nil, newNonExistentWorkflowError(params.workflowName) + } + if registeredWorkflow.maxRetries > 0 { + params.maxRetries = registeredWorkflow.maxRetries + } // Check if we are within a workflow (and thus a child workflow) - parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) + parentWorkflowState, ok := c.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil // TODO Check if cancelled @@ -419,12 +514,12 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := dbos.systemDB.CheckChildWorkflow(dbosWorkflowContext, parentWorkflowState.workflowID, parentWorkflowState.stepCounter) + childWorkflowID, err := c.systemDB.CheckChildWorkflow(c.ctx, parentWorkflowState.workflowID, parentWorkflowState.stepID) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { - return &workflowPollingHandle[R]{workflowID: *childWorkflowID}, nil + return &workflowPollingHandle[any]{workflowID: *childWorkflowID, dbosContext: c}, nil } } @@ -436,25 +531,25 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp } workflowStatus := WorkflowStatus{ - Name: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // TODO factor out somewhere else so we dont' have to reflect here + Name: params.workflowName, ApplicationVersion: params.applicationVersion, - ExecutorID: dbos.executorID, + ExecutorID: c.GetExecutorID(), Status: status, ID: workflowID, CreatedAt: time.Now(), Deadline: params.deadline, // TODO compute the deadline based on the timeout Timeout: params.timeout, Input: input, - ApplicationID: dbos.applicationID, + ApplicationID: c.GetApplicationID(), QueueName: params.queueName, } // Init status and record child workflow relationship in a single transaction - tx, err := dbos.systemDB.(*systemDatabase).pool.Begin(dbosWorkflowContext) + tx, err := c.systemDB.(*systemDatabase).pool.Begin(c.ctx) if err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) } - defer tx.Rollback(dbosWorkflowContext) // Rollback if not committed + defer tx.Rollback(c.ctx) // Rollback if not committed // Insert workflow status with transaction insertInput := insertWorkflowStatusDBInput{ @@ -462,7 +557,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp maxRetries: params.maxRetries, tx: tx, } - insertStatusResult, err := dbos.systemDB.InsertWorkflowStatus(dbosWorkflowContext, insertInput) + insertStatusResult, err := c.systemDB.InsertWorkflowStatus(c.ctx, insertInput) if err != nil { return nil, err } @@ -470,185 +565,208 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { // Commit the transaction to update the number of attempts and/or enact the enqueue - if err := tx.Commit(dbosWorkflowContext); err != nil { + if err := tx.Commit(c.ctx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowPollingHandle[R]{workflowID: workflowStatus.ID}, nil + return &workflowPollingHandle[any]{workflowID: workflowStatus.ID, dbosContext: c}, nil } // Record child workflow relationship if this is a child workflow if isChildWorkflow { // Get the step ID that was used for generating the child workflow ID - stepID := parentWorkflowState.stepCounter + stepID := parentWorkflowState.stepID childInput := recordChildWorkflowDBInput{ parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: workflowStatus.ID, - stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this + stepName: params.workflowName, stepID: stepID, tx: tx, } - err = dbos.systemDB.RecordChildWorkflow(dbosWorkflowContext, childInput) + err = c.systemDB.RecordChildWorkflow(c.ctx, childInput) if err != nil { return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow: %v", err)) } } - // Commit the transaction - if err := tx.Commit(dbosWorkflowContext); err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) - } - // Channel to receive the outcome from the goroutine // The buffer size of 1 allows the goroutine to send the outcome without blocking // In addition it allows the channel to be garbage collected - outcomeChan := make(chan workflowOutcome[R], 1) - - // Create the handle - handle := &workflowHandle[R]{ - workflowID: workflowStatus.ID, - outcomeChan: outcomeChan, - } + outcomeChan := make(chan workflowOutcome[any], 1) // Create workflow state to track step execution wfState := &workflowState{ - workflowID: workflowStatus.ID, - stepCounter: -1, + workflowID: workflowID, + stepID: -1, } // Run the function in a goroutine - augmentUserContext := context.WithValue(ctx, workflowStateKey, wfState) - dbos.workflowsWg.Add(1) + workflowCtx := WithValue(c, workflowStateKey, wfState) + c.workflowsWg.Add(1) go func() { - defer dbos.workflowsWg.Done() - result, err := fn(augmentUserContext, input) + defer c.workflowsWg.Done() + result, err := fn(workflowCtx, input) status := WorkflowStatusSuccess if err != nil { status = WorkflowStatusError } - recordErr := dbos.systemDB.UpdateWorkflowOutcome(dbosWorkflowContext, updateWorkflowOutcomeDBInput{workflowID: workflowStatus.ID, status: status, err: err, output: result}) + recordErr := c.systemDB.UpdateWorkflowOutcome(c.ctx, updateWorkflowOutcomeDBInput{ + workflowID: workflowID, + status: status, + err: err, + output: result, + }) if recordErr != nil { - outcomeChan <- workflowOutcome[R]{result: *new(R), err: recordErr} + outcomeChan <- workflowOutcome[any]{result: nil, err: recordErr} close(outcomeChan) // Close the channel to signal completion return } - outcomeChan <- workflowOutcome[R]{result: result, err: err} + outcomeChan <- workflowOutcome[any]{result: result, err: err} close(outcomeChan) // Close the channel to signal completion }() - // Run the peer goroutine to handle cancellation and timeout - /* - if dbosWorkflowContext.Done() != nil { - getLogger().Debug("starting goroutine to handle workflow context cancellation or timeout") - go func() { - select { - case <-dbosWorkflowContext.Done(): - // The context was cancelled or timed out: record timeout or cancellation - // CANCEL WORKFLOW HERE - return - } - }() - } - */ + // Commit the transaction + if err := tx.Commit(c.ctx); err != nil { + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) + } - return handle, nil + return &workflowHandle[any]{workflowID: workflowID, outcomeChan: outcomeChan, dbosContext: c}, nil } /******************************/ /******* STEP FUNCTIONS *******/ /******************************/ -type StepFunc[P any, R any] func(ctx context.Context, input P) (R, error) +type StepFunc func(ctx context.Context, input any) (any, error) +type GenericStepFunc[P any, R any] func(ctx context.Context, input P) (R, error) + +const StepParamsKey DBOSContextKey = "stepParams" type StepParams struct { MaxRetries int BackoffFactor float64 BaseInterval time.Duration MaxInterval time.Duration -} - -// stepOption is a functional option for configuring step parameters -type stepOption func(*StepParams) - -// WithStepMaxRetries sets the maximum number of retries for a step -func WithStepMaxRetries(maxRetries int) stepOption { - return func(p *StepParams) { - p.MaxRetries = maxRetries + StepName string +} + +// setStepParamDefaults returns a StepParams struct with all defaults properly set +func setStepParamDefaults(params *StepParams, stepName string) *StepParams { + if params == nil { + return &StepParams{ + MaxRetries: 0, // Default to no retries + BackoffFactor: 2.0, + BaseInterval: 100 * time.Millisecond, // Default base interval + MaxInterval: 5 * time.Second, // Default max interval + StepName: typeErasedStepNameToStepName[stepName], + } } -} -// WithBackoffFactor sets the backoff factor for retries (multiplier for exponential backoff) -func WithBackoffFactor(backoffFactor float64) stepOption { - return func(p *StepParams) { - p.BackoffFactor = backoffFactor + // Set defaults for zero values + if params.BackoffFactor == 0 { + params.BackoffFactor = 2.0 // Default backoff factor } -} - -// WithBaseInterval sets the base delay for the first retry -func WithBaseInterval(baseInterval time.Duration) stepOption { - return func(p *StepParams) { - p.BaseInterval = baseInterval + if params.BaseInterval == 0 { + params.BaseInterval = 100 * time.Millisecond // Default base interval + } + if params.MaxInterval == 0 { + params.MaxInterval = 5 * time.Second // Default max interval + } + if params.StepName == "" { + // If the step name is not provided, use the function name + params.StepName = typeErasedStepNameToStepName[stepName] } + + return params } -// WithMaxInterval sets the maximum delay for retries -func WithMaxInterval(maxInterval time.Duration) stepOption { - return func(p *StepParams) { - p.MaxInterval = maxInterval +var typeErasedStepNameToStepName = make(map[string]string) + +func RunAsStep[P any, R any](ctx DBOSContext, fn GenericStepFunc[P, R], input P) (R, error) { + if ctx == nil { + return *new(R), newStepExecutionError("", "", "ctx cannot be nil") } -} -func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, opts ...stepOption) (R, error) { if fn == nil { return *new(R), newStepExecutionError("", "", "step function cannot be nil") } - stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + // Type-erase the function based on its actual type + typeErasedFn := StepFunc(func(ctx context.Context, input any) (any, error) { + typedInput, ok := input.(P) + if !ok { + return nil, newStepExecutionError("", "", fmt.Sprintf("unexpected input type: expected %T, got %T", *new(P), input)) + } + return fn(ctx, typedInput) + }) + + typeErasedStepNameToStepName[runtime.FuncForPC(reflect.ValueOf(typeErasedFn).Pointer()).Name()] = runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() - // Apply options to build params with defaults - params := StepParams{ - MaxRetries: 0, - BackoffFactor: 2.0, - BaseInterval: 500 * time.Millisecond, - MaxInterval: 1 * time.Hour, + // Call the executor method + result, err := ctx.RunAsStep(ctx, typeErasedFn, input) + if err != nil { + // In case the errors comes from the DBOS step logic, the result will be nil and we must handle it + if result == nil { + return *new(R), err + } + return result.(R), err } - for _, opt := range opts { - opt(¶ms) + + // Type-check and cast the result + typedResult, ok := result.(R) + if !ok { + return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result) } + return typedResult, nil +} + +func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, input any) (any, error) { // Get workflow state from context - wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + wfState, ok := c.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { - return *new(R), newStepExecutionError("", stepName, "workflow state not found in context: are you running this step within a workflow?") + // TODO: try to print step name + return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?") + } + + if fn == nil { + // TODO: try to print step name + return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil") } + // Look up for step parameters in the context and set defaults + params, ok := c.Value(StepParamsKey).(*StepParams) + if !ok { + params = nil + } + params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) + // If within a step, just run the function directly if wfState.isWithinStep { - return fn(ctx, input) + return fn(c, input) } - // Get next step ID - stepID := wfState.NextStepID() + // Setup step state + stepState := workflowState{ + workflowID: wfState.workflowID, + stepID: wfState.NextStepID(), // crucially, this increments the step ID on the *workflow* state + isWithinStep: true, + } // Check the step is cancelled, has already completed, or is called with a different name - recordedOutput, err := dbos.systemDB.CheckOperationExecution(ctx, checkOperationExecutionDBInput{ - workflowID: wfState.workflowID, - stepID: stepID, - stepName: stepName, + recordedOutput, err := c.systemDB.CheckOperationExecution(c.ctx, checkOperationExecutionDBInput{ + workflowID: stepState.workflowID, + stepID: stepState.stepID, + stepName: params.StepName, }) if err != nil { - return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) + return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("checking operation execution: %v", err)) } if recordedOutput != nil { - return recordedOutput.output.(R), recordedOutput.err + return recordedOutput.output, recordedOutput.err } - // Execute step with retry logic if MaxRetries > 0 - stepState := workflowState{ - workflowID: wfState.workflowID, - stepCounter: wfState.stepCounter, - isWithinStep: true, - } - stepCtx := context.WithValue(ctx, workflowStateKey, &stepState) + // Spawn a child DBOSContext with the step state + stepCtx := WithValue(c, workflowStateKey, &stepState) stepOutput, stepError := fn(stepCtx, input) @@ -665,12 +783,12 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op delay = time.Duration(math.Min(exponentialDelay, float64(params.MaxInterval))) } - getLogger().Error("step failed, retrying", "step_name", stepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) + getLogger().Error("step failed, retrying", "step_name", params.StepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) // Wait before retry select { - case <-ctx.Done(): - return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", ctx.Err())) + case <-c.ctx.Done(): + return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("context cancelled during retry: %v", c.ctx.Err())) case <-time.After(delay): // Continue to retry } @@ -688,7 +806,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // If max retries reached, create MaxStepRetriesExceeded error if retry == params.MaxRetries { - stepError = newMaxStepRetriesExceededError(wfState.workflowID, stepName, params.MaxRetries, joinedErrors) + stepError = newMaxStepRetriesExceededError(stepState.workflowID, params.StepName, params.MaxRetries, joinedErrors) break } } @@ -696,15 +814,15 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // Record the final result dbInput := recordOperationResultDBInput{ - workflowID: wfState.workflowID, - stepName: stepName, - stepID: stepID, + workflowID: stepState.workflowID, + stepName: params.StepName, + stepID: stepState.stepID, err: stepError, output: stepOutput, } - recErr := dbos.systemDB.RecordOperationResult(ctx, dbInput) + recErr := c.systemDB.RecordOperationResult(c.ctx, dbInput) if recErr != nil { - return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) + return nil, newStepExecutionError(stepState.workflowID, params.StepName, fmt.Sprintf("recording step outcome: %v", recErr)) } return stepOutput, stepError @@ -720,15 +838,22 @@ type WorkflowSendInput[R any] struct { Topic string } +func (c *dbosContext) Send(_ DBOSContext, input WorkflowSendInputInternal) error { + return c.systemDB.Send(c.ctx, input) +} + // Send sends a message to another workflow. // Send automatically registers the type of R for gob encoding -func Send[R any](ctx context.Context, input WorkflowSendInput[R]) error { +func Send[R any](ctx DBOSContext, input WorkflowSendInput[R]) error { + if ctx == nil { + return errors.New("ctx cannot be nil") + } var typedMessage R gob.Register(typedMessage) - return dbos.systemDB.Send(ctx, workflowSendInputInternal{ - destinationID: input.DestinationID, - message: input.Message, - topic: input.Topic, + return ctx.Send(ctx, WorkflowSendInputInternal{ + DestinationID: input.DestinationID, + Message: input.Message, + Topic: input.Topic, }) } @@ -737,8 +862,15 @@ type WorkflowRecvInput struct { Timeout time.Duration } -func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { - msg, err := dbos.systemDB.Recv(ctx, input) +func (c *dbosContext) Recv(_ DBOSContext, input WorkflowRecvInput) (any, error) { + return c.systemDB.Recv(c.ctx, input) +} + +func Recv[R any](ctx DBOSContext, input WorkflowRecvInput) (R, error) { + if ctx == nil { + return *new(R), errors.New("ctx cannot be nil") + } + msg, err := ctx.Recv(ctx, input) if err != nil { return *new(R), err } @@ -754,20 +886,27 @@ func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { return typedMessage, nil } -type WorkflowSetEventInput[R any] struct { +type WorkflowSetEventInputGeneric[R any] struct { Key string Message R } +func (c *dbosContext) SetEvent(_ DBOSContext, input WorkflowSetEventInput) error { + return c.systemDB.SetEvent(c.ctx, input) +} + // Sets an event from a workflow. // The event is a key value pair // SetEvent automatically registers the type of R for gob encoding -func SetEvent[R any](ctx context.Context, input WorkflowSetEventInput[R]) error { +func SetEvent[R any](ctx DBOSContext, input WorkflowSetEventInputGeneric[R]) error { + if ctx == nil { + return errors.New("ctx cannot be nil") + } var typedMessage R gob.Register(typedMessage) - return dbos.systemDB.SetEvent(ctx, workflowSetEventInputInternal{ - key: input.Key, - message: input.Message, + return ctx.SetEvent(ctx, WorkflowSetEventInput{ + Key: input.Key, + Message: input.Message, }) } @@ -777,8 +916,15 @@ type WorkflowGetEventInput struct { Timeout time.Duration } -func GetEvent[R any](ctx context.Context, input WorkflowGetEventInput) (R, error) { - value, err := dbos.systemDB.GetEvent(ctx, input) +func (c *dbosContext) GetEvent(_ DBOSContext, input WorkflowGetEventInput) (any, error) { + return c.systemDB.GetEvent(c.ctx, input) +} + +func GetEvent[R any](ctx DBOSContext, input WorkflowGetEventInput) (R, error) { + if ctx == nil { + return *new(R), errors.New("dbosCtx cannot be nil") + } + value, err := ctx.GetEvent(ctx, input) if err != nil { return *new(R), err } @@ -793,8 +939,8 @@ func GetEvent[R any](ctx context.Context, input WorkflowGetEventInput) (R, error return typedValue, nil } -func Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) { - return dbos.systemDB.Sleep(ctx, duration) +func (c *dbosContext) Sleep(duration time.Duration) (time.Duration, error) { + return c.systemDB.Sleep(c.ctx, duration) } /***********************************/ @@ -802,17 +948,32 @@ func Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) { /***********************************/ // GetWorkflowID retrieves the workflow ID from the context if called within a DBOS workflow -func GetWorkflowID(ctx context.Context) (string, error) { - wfState, ok := ctx.Value(workflowStateKey).(*workflowState) +func (c *dbosContext) GetWorkflowID() (string, error) { + wfState, ok := c.ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { return "", errors.New("not within a DBOS workflow context") } return wfState.workflowID, nil } -func RetrieveWorkflow[R any](workflowID string) (workflowPollingHandle[R], error) { - ctx := context.Background() - workflowStatus, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ +func (c *dbosContext) RetrieveWorkflow(_ DBOSContext, workflowID string) (WorkflowHandle[any], error) { + workflowStatus, err := c.systemDB.ListWorkflows(c.ctx, listWorkflowsDBInput{ + workflowIDs: []string{workflowID}, + }) + if err != nil { + return nil, fmt.Errorf("failed to retrieve workflow status: %w", err) + } + if len(workflowStatus) == 0 { + return nil, newNonExistentWorkflowError(workflowID) + } + return &workflowPollingHandle[any]{workflowID: workflowID, dbosContext: c}, nil +} + +func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { + if ctx == nil { + return workflowPollingHandle[R]{}, errors.New("dbosCtx cannot be nil") + } + workflowStatus, err := ctx.(*dbosContext).systemDB.ListWorkflows(ctx.(*dbosContext).ctx, listWorkflowsDBInput{ workflowIDs: []string{workflowID}, }) if err != nil { @@ -821,5 +982,5 @@ func RetrieveWorkflow[R any](workflowID string) (workflowPollingHandle[R], error if len(workflowStatus) == 0 { return workflowPollingHandle[R]{}, newNonExistentWorkflowError(workflowID) } - return workflowPollingHandle[R]{workflowID: workflowID}, nil + return workflowPollingHandle[R]{workflowID: workflowID, dbosContext: ctx}, nil } diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 7fad1361..239da430 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -6,8 +6,7 @@ Test workflow and steps features [x] workflow idempotency [x] workflow DLQ [] workflow conflicting name -[] workflow timeout -[] workflow deadlines +[] workflow timeouts & deadlines (including child workflows) */ import ( @@ -24,39 +23,16 @@ import ( // Global counter for idempotency testing var idempotencyCounter int64 -var ( - simpleWf = WithWorkflow(simpleWorkflow) - simpleWfError = WithWorkflow(simpleWorkflowError) - simpleWfWithStep = WithWorkflow(simpleWorkflowWithStep) - simpleWfWithStepError = WithWorkflow(simpleWorkflowWithStepError) - // struct methods - s = workflowStruct{} - simpleWfStruct = WithWorkflow(s.simpleWorkflow) - simpleWfValue = WithWorkflow(s.simpleWorkflowValue) - // interface method workflow - workflowIface TestWorkflowInterface = &workflowImplementation{ - field: "example", - } - simpleWfIface = WithWorkflow(workflowIface.Execute) - // Generic workflow - wfInt = WithWorkflow(Identity[string]) // FIXME make this an int eventually - // Closure with captured state - prefix = "hello-" - wfClose = WithWorkflow(func(ctx context.Context, in string) (string, error) { - return prefix + in, nil - }) -) - -func simpleWorkflow(ctxt context.Context, input string) (string, error) { +func simpleWorkflow(dbosCtx DBOSContext, input string) (string, error) { return input, nil } -func simpleWorkflowError(ctx context.Context, input string) (int, error) { +func simpleWorkflowError(dbosCtx DBOSContext, input string) (int, error) { return 0, fmt.Errorf("failure") } -func simpleWorkflowWithStep(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, simpleStep, input) +func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, simpleStep, input) } func simpleStep(ctx context.Context, input string) (string, error) { @@ -67,8 +43,8 @@ func simpleStepError(ctx context.Context, input string) (string, error) { return "", fmt.Errorf("step failure") } -func simpleWorkflowWithStepError(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, simpleStepError, input) +func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, simpleStepError, input) } // idempotencyWorkflow increments a global counter and returns the input @@ -81,45 +57,67 @@ func incrementCounter(_ context.Context, value int64) (int64, error) { type workflowStruct struct{} // Pointer receiver method -func (w *workflowStruct) simpleWorkflow(ctx context.Context, input string) (string, error) { - return simpleWorkflow(ctx, input) +func (w *workflowStruct) simpleWorkflow(dbosCtx DBOSContext, input string) (string, error) { + return simpleWorkflow(dbosCtx, input) } // Value receiver method on the same struct -func (w workflowStruct) simpleWorkflowValue(ctx context.Context, input string) (string, error) { +func (w workflowStruct) simpleWorkflowValue(dbosCtx DBOSContext, input string) (string, error) { return input + "-value", nil } // interface for workflow methods type TestWorkflowInterface interface { - Execute(ctx context.Context, input string) (string, error) + Execute(dbosCtx DBOSContext, input string) (string, error) } type workflowImplementation struct { field string } -func (w *workflowImplementation) Execute(ctx context.Context, input string) (string, error) { +func (w *workflowImplementation) Execute(dbosCtx DBOSContext, input string) (string, error) { return input + "-" + w.field + "-interface", nil } // Generic workflow function -func Identity[T any](ctx context.Context, in T) (T, error) { +func Identity[T any](dbosCtx DBOSContext, in T) (T, error) { return in, nil } -var ( - anonymousWf = WithWorkflow(func(ctx context.Context, in string) (string, error) { - return "anonymous-" + in, nil - }) -) +func TestWorkflowsRegistration(t *testing.T) { + dbosCtx := setupDBOS(t) -func TestWorkflowsWrapping(t *testing.T) { - setupDBOS(t) + // Setup workflows with executor + RegisterWorkflow(dbosCtx, simpleWorkflow) + RegisterWorkflow(dbosCtx, simpleWorkflowError) + RegisterWorkflow(dbosCtx, simpleWorkflowWithStep) + RegisterWorkflow(dbosCtx, simpleWorkflowWithStepError) + // struct methods + s := workflowStruct{} + RegisterWorkflow(dbosCtx, s.simpleWorkflow) + RegisterWorkflow(dbosCtx, s.simpleWorkflowValue) + // interface method workflow + workflowIface := TestWorkflowInterface(&workflowImplementation{ + field: "example", + }) + RegisterWorkflow(dbosCtx, workflowIface.Execute) + // Generic workflow + RegisterWorkflow(dbosCtx, Identity[int]) + // Closure with captured state + prefix := "hello-" + closureWorkflow := func(dbosCtx DBOSContext, in string) (string, error) { + return prefix + in, nil + } + RegisterWorkflow(dbosCtx, closureWorkflow) + // Anonymous workflow + anonymousWorkflow := func(dbosCtx DBOSContext, in string) (string, error) { + return "anonymous-" + in, nil + } + RegisterWorkflow(dbosCtx, anonymousWorkflow) type testCase struct { name string - workflowFunc func(context.Context, string, ...workflowOption) (any, error) + workflowFunc func(DBOSContext, string, ...WorkflowOption) (any, error) input string expectedResult any expectError bool @@ -129,13 +127,13 @@ func TestWorkflowsWrapping(t *testing.T) { tests := []testCase{ { name: "SimpleWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWf(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflow, input, opts...) if err != nil { return nil, err } - result, err := handle.GetResult(ctx) - _, err2 := handle.GetResult(ctx) + result, err := handle.GetResult() + _, err2 := handle.GetResult() if err2 == nil { t.Fatal("Second call to GetResult should return an error") } @@ -151,12 +149,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowError", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWfError(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflowError, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectError: true, @@ -164,12 +162,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStep", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWfWithStep(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflowWithStep, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "from step", @@ -177,12 +175,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowStruct", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWfStruct(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, s.simpleWorkflow, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "echo", @@ -190,12 +188,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "ValueReceiverWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWfValue(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, s.simpleWorkflowValue, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "echo-value", @@ -203,12 +201,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "interfaceMethodWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWfIface(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, workflowIface.Execute, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectedResult: "echo-example-interface", @@ -216,26 +214,25 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "GenericWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - // For generic workflow, we need to convert string to int for testing - handle, err := wfInt(ctx, "42", opts...) // FIXME for now this returns a string because sys db accepts this + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, Identity, 42, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "42", // input not used in this case - expectedResult: "42", // FIXME make this an int eventually + expectedResult: 42, expectError: false, }, { name: "ClosureWithCapturedState", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := wfClose(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, closureWorkflow, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "world", expectedResult: "hello-world", @@ -243,12 +240,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "AnonymousClosure", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := anonymousWf(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, anonymousWorkflow, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "test", expectedResult: "anonymous-test", @@ -256,12 +253,12 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStepError", - workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { - handle, err := simpleWfWithStepError(ctx, input, opts...) + workflowFunc: func(dbosCtx DBOSContext, input string, opts ...WorkflowOption) (any, error) { + handle, err := RunAsWorkflow(dbosCtx, simpleWorkflowWithStepError, input, opts...) if err != nil { return nil, err } - return handle.GetResult(ctx) + return handle.GetResult() }, input: "echo", expectError: true, @@ -271,7 +268,7 @@ func TestWorkflowsWrapping(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := tc.workflowFunc(context.Background(), tc.input, WithWorkflowID(uuid.NewString())) + result, err := tc.workflowFunc(dbosCtx, tc.input, WithWorkflowID(uuid.NewString())) if tc.expectError { if err == nil { @@ -293,11 +290,11 @@ func TestWorkflowsWrapping(t *testing.T) { } func stepWithinAStep(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, simpleStep, input) + return simpleStep(ctx, input) } -func stepWithinAStepWorkflow(ctx context.Context, input string) (string, error) { - return RunAsStep(ctx, stepWithinAStep, input) +func stepWithinAStepWorkflow(dbosCtx DBOSContext, input string) (string, error) { + return RunAsStep(dbosCtx, stepWithinAStep, input) } // Global counter for retry testing @@ -310,33 +307,33 @@ func stepRetryAlwaysFailsStep(ctx context.Context, input string) (string, error) var stepIdempotencyCounter int -func stepIdempotencyTest(ctx context.Context, input string) (string, error) { +func stepIdempotencyTest(ctx context.Context, input int) (string, error) { stepIdempotencyCounter++ - return input, nil + return "", nil } -func stepRetryWorkflow(ctx context.Context, input string) (string, error) { - RunAsStep(ctx, stepIdempotencyTest, input) - return RunAsStep(ctx, stepRetryAlwaysFailsStep, input, - WithStepMaxRetries(5), - WithBackoffFactor(2.0), - WithBaseInterval(1*time.Millisecond), - WithMaxInterval(10*time.Millisecond)) -} +func stepRetryWorkflow(dbosCtx DBOSContext, input string) (string, error) { + RunAsStep(dbosCtx, stepIdempotencyTest, 1) + stepCtx := WithValue(dbosCtx, StepParamsKey, &StepParams{ + MaxRetries: 5, + BaseInterval: 1 * time.Millisecond, + MaxInterval: 10 * time.Millisecond, + }) -var ( - stepWithinAStepWf = WithWorkflow(stepWithinAStepWorkflow) - stepRetryWf = WithWorkflow(stepRetryWorkflow) -) + return RunAsStep(stepCtx, stepRetryAlwaysFailsStep, input) +} +// TODO: step params func TestSteps(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) - t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { - ctx := context.Background() + // Create workflows with executor + RegisterWorkflow(dbosCtx, stepWithinAStepWorkflow) + RegisterWorkflow(dbosCtx, stepRetryWorkflow) + t.Run("StepsMustRunInsideWorkflows", func(t *testing.T) { // Attempt to run a step outside of a workflow context - _, err := RunAsStep(ctx, simpleStep, "test") + _, err := RunAsStep(dbosCtx, simpleStep, "test") if err == nil { t.Fatal("expected error when running step outside of workflow context, but got none") } @@ -359,11 +356,11 @@ func TestSteps(t *testing.T) { }) t.Run("StepWithinAStepAreJustFunctions", func(t *testing.T) { - handle, err := stepWithinAStepWf(context.Background(), "test") + handle, err := RunAsWorkflow(dbosCtx, stepWithinAStepWorkflow, "test") if err != nil { t.Fatal("failed to run step within a step:", err) } - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatal("failed to get result from step within a step:", err) } @@ -371,7 +368,7 @@ func TestSteps(t *testing.T) { t.Fatalf("expected result 'from step', got '%s'", result) } - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatal("failed to list steps:", err) } @@ -386,12 +383,12 @@ func TestSteps(t *testing.T) { stepIdempotencyCounter = 0 // Execute the workflow - handle, err := stepRetryWf(context.Background(), "test") + handle, err := RunAsWorkflow(dbosCtx, stepRetryWorkflow, "test") if err != nil { t.Fatal("failed to start retry workflow:", err) } - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err == nil { t.Fatal("expected error from failing workflow but got none") } @@ -426,7 +423,7 @@ func TestSteps(t *testing.T) { } // Verify that the failed step was still recorded in the database - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatal("failed to get workflow steps:", err) } @@ -452,9 +449,15 @@ func TestSteps(t *testing.T) { }) } -var ( - childWf = WithWorkflow(func(ctx context.Context, i int) (string, error) { - workflowID, err := GetWorkflowID(ctx) +// Functions that create child workflows - moved to test function where executor is available + +// TODO Check timeouts behaviors for parents and children (e.g. awaited cancelled, etc) +func TestChildWorkflow(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Create child workflows with executor + childWf := func(dbosCtx DBOSContext, i int) (string, error) { + workflowID, err := dbosCtx.GetWorkflowID() if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) } @@ -463,15 +466,17 @@ var ( return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedCurrentID, workflowID) } // XXX right now the steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow - return RunAsStep(ctx, simpleStep, "") - }) - parentWf = WithWorkflow(func(ctx context.Context, i int) (string, error) { - workflowID, err := GetWorkflowID(ctx) + return RunAsStep(dbosCtx, simpleStep, "") + } + RegisterWorkflow(dbosCtx, childWf) + + parentWf := func(dbosCtx DBOSContext, i int) (string, error) { + workflowID, err := dbosCtx.GetWorkflowID() if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) } - childHandle, err := childWf(ctx, i) + childHandle, err := RunAsWorkflow(dbosCtx, childWf, i) if err != nil { return "", err } @@ -488,16 +493,18 @@ var ( if childWorkflowID != expectedChildID { return "", fmt.Errorf("expected childWf ID to be %s, got %s", expectedChildID, childWorkflowID) } - return childHandle.GetResult(ctx) - }) - grandParentWf = WithWorkflow(func(ctx context.Context, _ string) (string, error) { + return childHandle.GetResult() + } + RegisterWorkflow(dbosCtx, parentWf) + + grandParentWf := func(dbosCtx DBOSContext, _ string) (string, error) { for i := range 3 { - workflowID, err := GetWorkflowID(ctx) + workflowID, err := dbosCtx.GetWorkflowID() if err != nil { return "", fmt.Errorf("failed to get workflow ID: %v", err) } - childHandle, err := parentWf(ctx, i) + childHandle, err := RunAsWorkflow(dbosCtx, parentWf, i) if err != nil { return "", err } @@ -516,7 +523,7 @@ var ( } // Calling the child a second time should return a polling handle - childHandle, err = parentWf(ctx, i, WithWorkflowID(childHandle.GetWorkflowID())) + childHandle, err = RunAsWorkflow(dbosCtx, parentWf, i, WithWorkflowID(childHandle.GetWorkflowID())) if err != nil { return "", err } @@ -528,32 +535,25 @@ var ( } return "", nil - }) -) - -// TODO Check timeouts behaviors for parents and children (e.g. awaited cancelled, etc) -func TestChildWorkflow(t *testing.T) { - setupDBOS(t) + } + RegisterWorkflow(dbosCtx, grandParentWf) t.Run("ChildWorkflowIDPattern", func(t *testing.T) { - h, err := grandParentWf(context.Background(), "") + h, err := RunAsWorkflow(dbosCtx, grandParentWf, "") if err != nil { t.Fatalf("failed to execute grand parent workflow: %v", err) } - _, err = h.GetResult(context.Background()) + _, err = h.GetResult() if err != nil { t.Fatalf("failed to get result from grand parent workflow: %v", err) } }) } -var ( - idempotencyWf = WithWorkflow(idempotencyWorkflow) - idempotencyWfWithStep = WithWorkflow(idempotencyWorkflowWithStep) -) +// Idempotency workflows moved to test functions -func idempotencyWorkflow(ctx context.Context, input string) (string, error) { - incrementCounter(ctx, 1) +func idempotencyWorkflow(dbosCtx DBOSContext, input string) (string, error) { + RunAsStep(dbosCtx, incrementCounter, int64(1)) return input, nil } @@ -566,15 +566,16 @@ func blockingStep(ctx context.Context, input string) (string, error) { var idempotencyWorkflowWithStepEvent *Event -func idempotencyWorkflowWithStep(ctx context.Context, input string) (int64, error) { - RunAsStep(ctx, incrementCounter, 1) +func idempotencyWorkflowWithStep(dbosCtx DBOSContext, input string) (int64, error) { + RunAsStep(dbosCtx, incrementCounter, int64(1)) idempotencyWorkflowWithStepEvent.Set() - RunAsStep(ctx, blockingStep, input) + RunAsStep(dbosCtx, blockingStep, input) return idempotencyCounter, nil } func TestWorkflowIdempotency(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, idempotencyWorkflow) t.Run("WorkflowExecutedOnlyOnce", func(t *testing.T) { idempotencyCounter = 0 @@ -584,25 +585,29 @@ func TestWorkflowIdempotency(t *testing.T) { // Execute the same workflow twice with the same ID // First execution - handle1, err := idempotencyWf(context.Background(), input, WithWorkflowID(workflowID)) + handle1, err := RunAsWorkflow(dbosCtx, idempotencyWorkflow, input, WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to execute workflow first time: %v", err) } - result1, err := handle1.GetResult(context.Background()) + result1, err := handle1.GetResult() if err != nil { t.Fatalf("failed to get result from first execution: %v", err) } // Second execution with the same workflow ID - handle2, err := idempotencyWf(context.Background(), input, WithWorkflowID(workflowID)) + handle2, err := RunAsWorkflow(dbosCtx, idempotencyWorkflow, input, WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to execute workflow second time: %v", err) } - result2, err := handle2.GetResult(context.Background()) + result2, err := handle2.GetResult() if err != nil { t.Fatalf("failed to get result from second execution: %v", err) } + if handle1.GetWorkflowID() != handle2.GetWorkflowID() { + t.Fatalf("expected both handles to represent the same workflow ID, got %s and %s", handle2.GetWorkflowID(), handle1.GetWorkflowID()) + } + // Verify the second handle is a polling handle _, ok := handle2.(*workflowPollingHandle[string]) if !ok { @@ -622,7 +627,8 @@ func TestWorkflowIdempotency(t *testing.T) { } func TestWorkflowRecovery(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, idempotencyWorkflowWithStep) t.Run("RecoveryResumeWhereItLeftOff", func(t *testing.T) { // Reset the global counter idempotencyCounter = 0 @@ -631,7 +637,7 @@ func TestWorkflowRecovery(t *testing.T) { input := "recovery-test" idempotencyWorkflowWithStepEvent = NewEvent() blockingStepStopEvent = NewEvent() - handle1, err := idempotencyWfWithStep(context.Background(), input) + handle1, err := RunAsWorkflow(dbosCtx, idempotencyWorkflowWithStep, input) if err != nil { t.Fatalf("failed to execute workflow first time: %v", err) } @@ -639,7 +645,7 @@ func TestWorkflowRecovery(t *testing.T) { idempotencyWorkflowWithStepEvent.Wait() // Wait for the first step to complete. The second spins forever. // Run recovery for pending workflows with "local" executor - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -668,7 +674,7 @@ func TestWorkflowRecovery(t *testing.T) { } // Using ListWorkflows, retrieve the status of the workflow - workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflows, err := dbosCtx.(*dbosContext).systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ workflowIDs: []string{handle1.GetWorkflowID()}, }) if err != nil { @@ -688,7 +694,7 @@ func TestWorkflowRecovery(t *testing.T) { // unlock the workflow & wait for result blockingStepStopEvent.Set() // This will allow the blocking step to complete - result, err := recoveredHandle.GetResult(context.Background()) + result, err := recoveredHandle.GetResult() if err != nil { t.Fatalf("failed to get result from recovered handle: %v", err) } @@ -700,28 +706,32 @@ func TestWorkflowRecovery(t *testing.T) { var ( maxRecoveryAttempts = 20 - deadLetterQueueWf = WithWorkflow(deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) - infiniteDeadLetterQueueWf = WithWorkflow(infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries deadLetterQueueStartEvent *Event deadLetterQueueEvent *Event recoveryCount int64 ) -func deadLetterQueueWorkflow(ctx context.Context, input string) (int, error) { +func deadLetterQueueWorkflow(ctx DBOSContext, input string) (int, error) { recoveryCount++ - fmt.Printf("Dead letter queue workflow started, recovery count: %d\n", recoveryCount) + wfid, err := ctx.GetWorkflowID() + if err != nil { + return 0, fmt.Errorf("failed to get workflow ID: %v", err) + } + fmt.Printf("Dead letter queue workflow %s started, recovery count: %d\n", wfid, recoveryCount) deadLetterQueueStartEvent.Set() deadLetterQueueEvent.Wait() return 0, nil } -func infiniteDeadLetterQueueWorkflow(ctx context.Context, input string) (int, error) { +func infiniteDeadLetterQueueWorkflow(ctx DBOSContext, input string) (int, error) { deadLetterQueueStartEvent.Set() deadLetterQueueEvent.Wait() return 0, nil } func TestWorkflowDeadLetterQueue(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, deadLetterQueueWorkflow, WithMaxRetries(maxRecoveryAttempts)) + RegisterWorkflow(dbosCtx, infiniteDeadLetterQueueWorkflow, WithMaxRetries(-1)) // A negative value means infinite retries t.Run("DeadLetterQueueBehavior", func(t *testing.T) { deadLetterQueueEvent = NewEvent() @@ -730,7 +740,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Start a workflow that blocks forever wfID := uuid.NewString() - handle, err := deadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + handle, err := RunAsWorkflow(dbosCtx, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start dead letter queue workflow: %v", err) } @@ -739,7 +749,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Attempt to recover the blocked workflow the maximum number of times for i := range maxRecoveryAttempts { - _, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + _, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) } @@ -752,7 +762,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { } // Verify an additional attempt throws a DLQ error and puts the workflow in the DLQ status - _, err = recoverPendingWorkflows(context.Background(), []string{"local"}) + _, err = recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err == nil { t.Fatal("expected dead letter queue error but got none") } @@ -775,7 +785,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { } // Verify that attempting to start a workflow with the same ID throws a DLQ error - _, err = deadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + _, err = RunAsWorkflow(dbosCtx, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err == nil { t.Fatal("expected dead letter queue error when restarting workflow with same ID but got none") } @@ -795,7 +805,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { resumedHandle, err := ... // Recover pending workflows again - should work without error - _, err = recoverPendingWorkflows(context.Background(), []string{"local"}) + _, err = recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows after resume: %v", err) } @@ -829,7 +839,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Verify that retries of a completed workflow do not raise the DLQ exception for i := 0; i < maxRecoveryAttempts*2; i++ { - _, err = deadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + _, err = RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("unexpected error when retrying completed workflow: %v", err) } @@ -844,7 +854,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Verify that a workflow with MaxRetries=0 (infinite retries) is retried infinitely wfID := uuid.NewString() - handle, err := infiniteDeadLetterQueueWf(context.Background(), "test", WithWorkflowID(wfID)) + handle, err := RunAsWorkflow(dbosCtx, infiniteDeadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) if err != nil { t.Fatalf("failed to start infinite dead letter queue workflow: %v", err) } @@ -854,7 +864,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Attempt to recover the blocked workflow many times (should never fail) handles := []WorkflowHandle[any]{} for i := range _DEFAULT_MAX_RECOVERY_ATTEMPTS * 2 { - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) } @@ -866,7 +876,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Complete the workflow deadLetterQueueEvent.Set() - result, err := handle.GetResult(context.Background()) + result, err := handle.GetResult() if err != nil { t.Fatalf("failed to get result from infinite dead letter queue workflow: %v", err) } @@ -876,7 +886,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { // Wait for all handles to complete for i, h := range handles { - result, err := h.GetResult(context.Background()) + result, err := h.GetResult() if err != nil { t.Fatalf("failed to get result from handle %d: %v", i, err) } @@ -890,11 +900,15 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { var ( counter = 0 counter1Ch = make(chan time.Time, 100) - _ = WithWorkflow(func(ctx context.Context, scheduledTime time.Time) (string, error) { +) + +func TestScheduledWorkflows(t *testing.T) { + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, func(ctx DBOSContext, scheduledTime time.Time) (string, error) { startTime := time.Now() counter++ if counter == 10 { - return "", fmt.Errorf("counter reached 100, stopping workflow") + return "", fmt.Errorf("counter reached 10, stopping workflow") } select { case counter1Ch <- startTime: @@ -902,10 +916,11 @@ var ( } return fmt.Sprintf("Scheduled workflow scheduled at time %v and executed at time %v", scheduledTime, startTime), nil }, WithSchedule("* * * * * *")) // Every second -) -func TestScheduledWorkflows(t *testing.T) { - setupDBOS(t) + err := dbosCtx.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS: %v", err) + } // Helper function to collect execution times collectExecutionTimes := func(ch chan time.Time, target int, timeout time.Duration) ([]time.Time, error) { @@ -956,7 +971,7 @@ func TestScheduledWorkflows(t *testing.T) { // Stop the workflowScheduler and check if it stops executing currentCounter := counter - workflowScheduler.Stop() + dbosCtx.(*dbosContext).getWorkflowScheduler().Stop() time.Sleep(3 * time.Second) // Wait a bit to ensure no more executions if counter >= currentCounter+2 { t.Fatalf("Scheduled workflow continued executing after stopping scheduler: %d (expected < %d)", counter, currentCounter+2) @@ -965,17 +980,9 @@ func TestScheduledWorkflows(t *testing.T) { } var ( - sendWf = WithWorkflow(sendWorkflow) - receiveWf = WithWorkflow(receiveWorkflow) - receiveWfCoordinated = WithWorkflow(receiveWorkflowCoordinated) - sendStructWf = WithWorkflow(sendStructWorkflow) - receiveStructWf = WithWorkflow(receiveStructWorkflow) - sendIdempotencyWf = WithWorkflow(sendIdempotencyWorkflow) sendIdempotencyEvent = NewEvent() - recvIdempotencyWf = WithWorkflow(receiveIdempotencyWorkflow) receiveIdempotencyStartEvent = NewEvent() receiveIdempotencyStopEvent = NewEvent() - sendWithinStepWf = WithWorkflow(workflowThatCallsSendInStep) numConcurrentRecvWfs = 5 concurrentRecvReadyEvents = make([]*Event, numConcurrentRecvWfs) concurrentRecvStartEvent = NewEvent() @@ -986,8 +993,12 @@ type sendWorkflowInput struct { Topic string } -func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) +func sendWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { + err := Send(ctx, WorkflowSendInput[string]{ + DestinationID: input.DestinationID, + Topic: input.Topic, + Message: "message1", + }) if err != nil { return "", err } @@ -1002,7 +1013,7 @@ func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) return "", nil } -func receiveWorkflow(ctx context.Context, topic string) (string, error) { +func receiveWorkflow(ctx DBOSContext, topic string) (string, error) { msg1, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err @@ -1018,7 +1029,7 @@ func receiveWorkflow(ctx context.Context, topic string) (string, error) { return msg1 + "-" + msg2 + "-" + msg3, nil } -func receiveWorkflowCoordinated(ctx context.Context, input struct { +func receiveWorkflowCoordinated(ctx DBOSContext, input struct { Topic string i int }) (string, error) { @@ -1037,17 +1048,17 @@ func receiveWorkflowCoordinated(ctx context.Context, input struct { return msg, nil } -func sendStructWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { +func sendStructWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { testStruct := sendRecvType{Value: "test-struct-value"} err := Send(ctx, WorkflowSendInput[sendRecvType]{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) return "", err } -func receiveStructWorkflow(ctx context.Context, topic string) (sendRecvType, error) { +func receiveStructWorkflow(ctx DBOSContext, topic string) (sendRecvType, error) { return Recv[sendRecvType](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) } -func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { +func sendIdempotencyWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) if err != nil { return "", err @@ -1056,7 +1067,7 @@ func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (stri return "idempotent-send-completed", nil } -func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, error) { +func receiveIdempotencyWorkflow(ctx DBOSContext, topic string) (string, error) { msg, err := Recv[string](ctx, WorkflowRecvInput{Topic: topic, Timeout: 3 * time.Second}) if err != nil { return "", err @@ -1067,7 +1078,7 @@ func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, erro } func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput[string]{ + err := Send(ctx.(DBOSContext), WorkflowSendInput[string]{ DestinationID: input.DestinationID, Topic: input.Topic, Message: "message-from-step", @@ -1078,7 +1089,7 @@ func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, er return "send-completed", nil } -func workflowThatCallsSendInStep(ctx context.Context, input sendWorkflowInput) (string, error) { +func workflowThatCallsSendInStep(ctx DBOSContext, input sendWorkflowInput) (string, error) { return RunAsStep(ctx, stepThatCallsSend, input) } @@ -1087,51 +1098,61 @@ type sendRecvType struct { } func TestSendRecv(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + + // Register all send/recv workflows with executor + RegisterWorkflow(dbosCtx, sendWorkflow) + RegisterWorkflow(dbosCtx, receiveWorkflow) + RegisterWorkflow(dbosCtx, receiveWorkflowCoordinated) + RegisterWorkflow(dbosCtx, sendStructWorkflow) + RegisterWorkflow(dbosCtx, receiveStructWorkflow) + RegisterWorkflow(dbosCtx, sendIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, receiveIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, workflowThatCallsSendInStep) t.Run("SendRecvSuccess", func(t *testing.T) { // Start the receive workflow - receiveHandle, err := receiveWf(context.Background(), "test-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "test-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send a message to the receive workflow - handle, err := sendWf(context.Background(), sendWorkflowInput{ + handle, err := RunAsWorkflow(dbosCtx, sendWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "test-topic", }) if err != nil { t.Fatalf("failed to send message: %v", err) } - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err != nil { t.Fatalf("failed to get result from send workflow: %v", err) } start := time.Now() - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } if result != "message1-message2-message3" { t.Fatalf("expected received message to be 'message1-message2-message3', got '%s'", result) } - // XXX let's see how this works when all the tests run - if time.Since(start) > 5*time.Second { + // XXX This is not a great condition: when all the tests run there's quite some randomness to this + if time.Since(start) > 10*time.Second { t.Fatalf("receive workflow took too long to complete, expected < 5s, got %v", time.Since(start)) } }) t.Run("SendRecvCustomStruct", func(t *testing.T) { // Start the receive workflow - receiveHandle, err := receiveStructWf(context.Background(), "struct-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveStructWorkflow, "struct-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send the struct to the receive workflow - sendHandle, err := sendStructWf(context.Background(), sendWorkflowInput{ + sendHandle, err := RunAsWorkflow(dbosCtx, sendStructWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "struct-topic", }) @@ -1139,13 +1160,13 @@ func TestSendRecv(t *testing.T) { t.Fatalf("failed to send struct: %v", err) } - _, err = sendHandle.GetResult(context.Background()) + _, err = sendHandle.GetResult() if err != nil { t.Fatalf("failed to get result from send workflow: %v", err) } // Get the result from receive workflow - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1161,7 +1182,7 @@ func TestSendRecv(t *testing.T) { destUUID := uuid.NewString() // Send to non-existent UUID should fail - handle, err := sendWf(context.Background(), sendWorkflowInput{ + handle, err := RunAsWorkflow(dbosCtx, sendWorkflow, sendWorkflowInput{ DestinationID: destUUID, Topic: "testtopic", }) @@ -1169,7 +1190,7 @@ func TestSendRecv(t *testing.T) { t.Fatalf("failed to start send workflow: %v", err) } - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err == nil { t.Fatal("expected error when sending to non-existent UUID but got none") } @@ -1191,11 +1212,11 @@ func TestSendRecv(t *testing.T) { t.Run("RecvTimeout", func(t *testing.T) { // Create a receive workflow that tries to receive a message but no send happens - receiveHandle, err := receiveWf(context.Background(), "timeout-test-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "timeout-test-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if result != "--" { t.Fatalf("expected -- result on timeout, got '%s'", result) } @@ -1205,10 +1226,8 @@ func TestSendRecv(t *testing.T) { }) t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) { - ctx := context.Background() - // Attempt to run Recv outside of a workflow context - _, err := Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) + _, err := Recv[string](dbosCtx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) if err == nil { t.Fatal("expected error when running Recv outside of workflow context, but got none") } @@ -1232,15 +1251,14 @@ func TestSendRecv(t *testing.T) { t.Run("SendOutsideWorkflow", func(t *testing.T) { // Start a receive workflow to have a valid destination - receiveHandle, err := receiveWf(context.Background(), "outside-workflow-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "outside-workflow-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Send messages from outside a workflow context (should work now) - ctx := context.Background() for i := range 3 { - err = Send(ctx, WorkflowSendInput[string]{ + err = Send(dbosCtx, WorkflowSendInput[string]{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "outside-workflow-topic", Message: fmt.Sprintf("message%d", i+1), @@ -1251,7 +1269,7 @@ func TestSendRecv(t *testing.T) { } // Verify the receive workflow gets all messages - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1261,13 +1279,13 @@ func TestSendRecv(t *testing.T) { }) t.Run("SendRecvIdempotency", func(t *testing.T) { // Start the receive workflow and wait for it to be ready - receiveHandle, err := recvIdempotencyWf(context.Background(), "idempotency-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveIdempotencyWorkflow, "idempotency-topic") if err != nil { t.Fatalf("failed to start receive idempotency workflow: %v", err) } // Send the message to the receive workflow - sendHandle, err := sendIdempotencyWf(context.Background(), sendWorkflowInput{ + sendHandle, err := RunAsWorkflow(dbosCtx, sendIdempotencyWorkflow, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "idempotency-topic", }) @@ -1279,21 +1297,21 @@ func TestSendRecv(t *testing.T) { receiveIdempotencyStartEvent.Wait() // Attempt recovering both workflows. There should be only 2 steps recorded after recovery. - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } if len(recoveredHandles) != 2 { t.Fatalf("expected 2 recovered handles, got %d", len(recoveredHandles)) } - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) if err != nil { - t.Fatalf("failed to get steps for send idempotency workflow: %v", err) + t.Fatalf("failed to get workflow steps: %v", err) } if len(steps) != 1 { t.Fatalf("expected 1 step in send idempotency workflow, got %d", len(steps)) } - steps, err = dbos.systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) + steps, err = dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for receive idempotency workflow: %v", err) } @@ -1303,7 +1321,7 @@ func TestSendRecv(t *testing.T) { // Unblock the workflows to complete receiveIdempotencyStopEvent.Set() - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive idempotency workflow: %v", err) } @@ -1311,7 +1329,7 @@ func TestSendRecv(t *testing.T) { t.Fatalf("expected result to be 'm1', got '%s'", result) } sendIdempotencyEvent.Set() - result, err = sendHandle.GetResult(context.Background()) + result, err = sendHandle.GetResult() if err != nil { t.Fatalf("failed to get result from send idempotency workflow: %v", err) } @@ -1322,13 +1340,13 @@ func TestSendRecv(t *testing.T) { t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) { // Start a receive workflow to have a valid destination - receiveHandle, err := receiveWf(context.Background(), "send-within-step-topic") + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflow, "send-within-step-topic") if err != nil { t.Fatalf("failed to start receive workflow: %v", err) } // Execute the workflow that tries to call Send within a step - handle, err := sendWithinStepWf(context.Background(), sendWorkflowInput{ + handle, err := RunAsWorkflow(dbosCtx, workflowThatCallsSendInStep, sendWorkflowInput{ DestinationID: receiveHandle.GetWorkflowID(), Topic: "send-within-step-topic", }) @@ -1337,7 +1355,7 @@ func TestSendRecv(t *testing.T) { } // Expect the workflow to fail with the specific error - _, err = handle.GetResult(context.Background()) + _, err = handle.GetResult() if err == nil { t.Fatal("expected error when calling Send within a step, but got none") } @@ -1359,7 +1377,7 @@ func TestSendRecv(t *testing.T) { } // Wait for the receive workflow to time out - result, err := receiveHandle.GetResult(context.Background()) + result, err := receiveHandle.GetResult() if err != nil { t.Fatalf("failed to get result from receive workflow: %v", err) } @@ -1368,7 +1386,7 @@ func TestSendRecv(t *testing.T) { } }) - t.Run("ConcurrentRecv", func(t *testing.T) { + t.Run("TestSendRecv", func(t *testing.T) { // Test concurrent receivers - only 1 should timeout, others should get errors receiveTopic := "concurrent-recv-topic" @@ -1382,7 +1400,7 @@ func TestSendRecv(t *testing.T) { // Start all receivers - they will signal when ready and wait for coordination for i := range numReceivers { concurrentRecvReadyEvents[i] = NewEvent() - receiveHandle, err := receiveWfCoordinated(context.Background(), struct { + receiveHandle, err := RunAsWorkflow(dbosCtx, receiveWorkflowCoordinated, struct { Topic string i int }{ @@ -1409,7 +1427,7 @@ func TestSendRecv(t *testing.T) { for i := range numReceivers { go func(index int) { defer wg.Done() - result, err := receiverHandles[index].GetResult(context.Background()) + result, err := receiverHandles[index].GetResult() if err != nil { errors <- err } else { @@ -1455,12 +1473,8 @@ func TestSendRecv(t *testing.T) { } var ( - setEventWf = WithWorkflow(setEventWorkflow) - getEventWf = WithWorkflow(getEventWorkflow) - setTwoEventsWf = WithWorkflow(setTwoEventsWorkflow) - setEventIdempotencyWf = WithWorkflow(setEventIdempotencyWorkflow) - getEventIdempotencyWf = WithWorkflow(getEventIdempotencyWorkflow) - setEventIdempotencyEvent = NewEvent() + setEventStartIdempotencyEvent = NewEvent() + setEvenStopIdempotencyEvent = NewEvent() getEventStartIdempotencyEvent = NewEvent() getEventStopIdempotencyEvent = NewEvent() setSecondEventSignal = NewEvent() @@ -1471,15 +1485,15 @@ type setEventWorkflowInput struct { Message string } -func setEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) +func setEventWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { + err := SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } return "event-set", nil } -func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { +func getEventWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { result, err := GetEvent[string](ctx, WorkflowGetEventInput{ TargetWorkflowID: input.Key, // Reusing Key field as target workflow ID Key: input.Message, // Reusing Message field as event key @@ -1491,9 +1505,9 @@ func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, return result, nil } -func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { +func setTwoEventsWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { // Set the first event - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: "event1", Message: "first-event-message"}) + err := SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: "event1", Message: "first-event-message"}) if err != nil { return "", err } @@ -1502,7 +1516,7 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str setSecondEventSignal.Wait() // Set the second event - err = SetEvent(ctx, WorkflowSetEventInput[string]{Key: "event2", Message: "second-event-message"}) + err = SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: "event2", Message: "second-event-message"}) if err != nil { return "", err } @@ -1510,16 +1524,17 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str return "two-events-set", nil } -func setEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) +func setEventIdempotencyWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { + err := SetEvent(ctx, WorkflowSetEventInputGeneric[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } - setEventIdempotencyEvent.Wait() + setEventStartIdempotencyEvent.Set() + setEvenStopIdempotencyEvent.Wait() return "idempotent-set-completed", nil } -func getEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { +func getEventIdempotencyWorkflow(ctx DBOSContext, input setEventWorkflowInput) (string, error) { result, err := GetEvent[string](ctx, WorkflowGetEventInput{ TargetWorkflowID: input.Key, Key: input.Message, @@ -1534,14 +1549,21 @@ func getEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInpu } func TestSetGetEvent(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + + // Register all set/get event workflows with executor + RegisterWorkflow(dbosCtx, setEventWorkflow) + RegisterWorkflow(dbosCtx, getEventWorkflow) + RegisterWorkflow(dbosCtx, setTwoEventsWorkflow) + RegisterWorkflow(dbosCtx, setEventIdempotencyWorkflow) + RegisterWorkflow(dbosCtx, getEventIdempotencyWorkflow) t.Run("SetGetEventFromWorkflow", func(t *testing.T) { // Clear the signal event before starting setSecondEventSignal.Clear() // Start the workflow that sets two events - setHandle, err := setTwoEventsWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setTwoEventsWorkflow, setEventWorkflowInput{ Key: "test-workflow", Message: "unused", }) @@ -1550,7 +1572,7 @@ func TestSetGetEvent(t *testing.T) { } // Start a workflow to get the first event - getFirstEventHandle, err := getEventWf(context.Background(), setEventWorkflowInput{ + getFirstEventHandle, err := RunAsWorkflow(dbosCtx, getEventWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), // Target workflow ID Message: "event1", // Event key }) @@ -1559,7 +1581,7 @@ func TestSetGetEvent(t *testing.T) { } // Verify we can get the first event - firstMessage, err := getFirstEventHandle.GetResult(context.Background()) + firstMessage, err := getFirstEventHandle.GetResult() if err != nil { t.Fatalf("failed to get result from first event workflow: %v", err) } @@ -1571,7 +1593,7 @@ func TestSetGetEvent(t *testing.T) { setSecondEventSignal.Set() // Start a workflow to get the second event - getSecondEventHandle, err := getEventWf(context.Background(), setEventWorkflowInput{ + getSecondEventHandle, err := RunAsWorkflow(dbosCtx, getEventWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), // Target workflow ID Message: "event2", // Event key }) @@ -1580,7 +1602,7 @@ func TestSetGetEvent(t *testing.T) { } // Verify we can get the second event - secondMessage, err := getSecondEventHandle.GetResult(context.Background()) + secondMessage, err := getSecondEventHandle.GetResult() if err != nil { t.Fatalf("failed to get result from second event workflow: %v", err) } @@ -1589,7 +1611,7 @@ func TestSetGetEvent(t *testing.T) { } // Wait for the workflow to complete - result, err := setHandle.GetResult(context.Background()) + result, err := setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set two events workflow: %v", err) } @@ -1600,7 +1622,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventFromOutsideWorkflow", func(t *testing.T) { // Start a workflow that sets an event - setHandle, err := setEventWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventWorkflow, setEventWorkflowInput{ Key: "test-key", Message: "test-message", }) @@ -1609,13 +1631,13 @@ func TestSetGetEvent(t *testing.T) { } // Wait for the event to be set - _, err = setHandle.GetResult(context.Background()) + _, err = setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set event workflow: %v", err) } // Start a workflow that gets the event from outside the original workflow - message, err := GetEvent[string](context.Background(), WorkflowGetEventInput{ + message, err := GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "test-key", Timeout: 3 * time.Second, @@ -1631,7 +1653,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("GetEventTimeout", func(t *testing.T) { // Try to get an event from a non-existent workflow nonExistentID := uuid.NewString() - message, err := GetEvent[string](context.Background(), WorkflowGetEventInput{ + message, err := GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: nonExistentID, Key: "test-key", Timeout: 3 * time.Second, @@ -1644,18 +1666,18 @@ func TestSetGetEvent(t *testing.T) { } // Try to get an event from an existing workflow but with a key that doesn't exist - setHandle, err := setEventWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventWorkflow, setEventWorkflowInput{ Key: "test-key", Message: "test-message", }) if err != nil { t.Fatal("failed to set event:", err) } - _, err = setHandle.GetResult(context.Background()) + _, err = setHandle.GetResult() if err != nil { t.Fatal("failed to get result from set event workflow:", err) } - message, err = GetEvent[string](context.Background(), WorkflowGetEventInput{ + message, err = GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "non-existent-key", Timeout: 3 * time.Second, @@ -1669,10 +1691,8 @@ func TestSetGetEvent(t *testing.T) { }) t.Run("SetGetEventMustRunInsideWorkflows", func(t *testing.T) { - ctx := context.Background() - // Attempt to run SetEvent outside of a workflow context - err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: "test-key", Message: "test-message"}) + err := SetEvent(dbosCtx, WorkflowSetEventInputGeneric[string]{Key: "test-key", Message: "test-message"}) if err == nil { t.Fatal("expected error when running SetEvent outside of workflow context, but got none") } @@ -1696,7 +1716,7 @@ func TestSetGetEvent(t *testing.T) { t.Run("SetGetEventIdempotency", func(t *testing.T) { // Start the set event workflow - setHandle, err := setEventIdempotencyWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventIdempotencyWorkflow, setEventWorkflowInput{ Key: "idempotency-key", Message: "idempotency-message", }) @@ -1705,7 +1725,7 @@ func TestSetGetEvent(t *testing.T) { } // Start the get event workflow - getHandle, err := getEventIdempotencyWf(context.Background(), setEventWorkflowInput{ + getHandle, err := RunAsWorkflow(dbosCtx, getEventIdempotencyWorkflow, setEventWorkflowInput{ Key: setHandle.GetWorkflowID(), Message: "idempotency-key", }) @@ -1713,12 +1733,14 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("failed to start get event idempotency workflow: %v", err) } - // Wait for the get event workflow to signal it has received the event + // Wait for the workflows to signal it has received the event getEventStartIdempotencyEvent.Wait() getEventStartIdempotencyEvent.Clear() + setEventStartIdempotencyEvent.Wait() + setEventStartIdempotencyEvent.Clear() // Attempt recovering both workflows. Each should have exactly 1 step. - recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) + recoveredHandles, err := recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows: %v", err) } @@ -1727,9 +1749,10 @@ func TestSetGetEvent(t *testing.T) { } getEventStartIdempotencyEvent.Wait() + setEventStartIdempotencyEvent.Wait() // Verify step counts - setSteps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) + setSteps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), setHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for set event idempotency workflow: %v", err) } @@ -1737,7 +1760,7 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("expected 1 step in set event idempotency workflow, got %d", len(setSteps)) } - getSteps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), getHandle.GetWorkflowID()) + getSteps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), getHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for get event idempotency workflow: %v", err) } @@ -1746,10 +1769,10 @@ func TestSetGetEvent(t *testing.T) { } // Complete the workflows - setEventIdempotencyEvent.Set() + setEvenStopIdempotencyEvent.Set() getEventStopIdempotencyEvent.Set() - setResult, err := setHandle.GetResult(context.Background()) + setResult, err := setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set event idempotency workflow: %v", err) } @@ -1757,18 +1780,41 @@ func TestSetGetEvent(t *testing.T) { t.Fatalf("expected result to be 'idempotent-set-completed', got '%s'", setResult) } - getResult, err := getHandle.GetResult(context.Background()) + getResult, err := getHandle.GetResult() if err != nil { t.Fatalf("failed to get result from get event idempotency workflow: %v", err) } if getResult != "idempotency-message" { t.Fatalf("expected result to be 'idempotency-message', got '%s'", getResult) } + + // Check the recovered handle returns the same result + for _, recoveredHandle := range recoveredHandles { + if recoveredHandle.GetWorkflowID() == setHandle.GetWorkflowID() { + recoveredSetResult, err := recoveredHandle.GetResult() + if err != nil { + t.Fatalf("failed to get result from recovered set event idempotency workflow: %v", err) + } + if recoveredSetResult != "idempotent-set-completed" { + t.Fatalf("expected recovered result to be 'idempotent-set-completed', got '%s'", recoveredSetResult) + + } + } + if recoveredHandle.GetWorkflowID() == getHandle.GetWorkflowID() { + recoveredGetResult, err := recoveredHandle.GetResult() + if err != nil { + t.Fatalf("failed to get result from recovered get event idempotency workflow: %v", err) + } + if recoveredGetResult != "idempotency-message" { + t.Fatalf("expected recovered result to be 'idempotency-message', got '%s'", recoveredGetResult) + } + } + } }) t.Run("ConcurrentGetEvent", func(t *testing.T) { // Set event - setHandle, err := setEventWf(context.Background(), setEventWorkflowInput{ + setHandle, err := RunAsWorkflow(dbosCtx, setEventWorkflow, setEventWorkflowInput{ Key: "concurrent-event-key", Message: "concurrent-event-message", }) @@ -1777,7 +1823,7 @@ func TestSetGetEvent(t *testing.T) { } // Wait for the set event workflow to complete - _, err = setHandle.GetResult(context.Background()) + _, err = setHandle.GetResult() if err != nil { t.Fatalf("failed to get result from set event workflow: %v", err) } @@ -1789,7 +1835,7 @@ func TestSetGetEvent(t *testing.T) { for range numGoroutines { go func() { defer wg.Done() - res, err := GetEvent[string](context.Background(), WorkflowGetEventInput{ + res, err := GetEvent[string](dbosCtx, WorkflowGetEventInput{ TargetWorkflowID: setHandle.GetWorkflowID(), Key: "concurrent-event-key", Timeout: 10 * time.Second, @@ -1815,13 +1861,12 @@ func TestSetGetEvent(t *testing.T) { } var ( - sleepRecoveryWf = WithWorkflow(sleepRecoveryWorkflow) sleepStartEvent *Event sleepStopEvent *Event ) -func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Duration, error) { - result, err := Sleep(ctx, duration) +func sleepRecoveryWorkflow(dbosCtx DBOSContext, duration time.Duration) (time.Duration, error) { + result, err := dbosCtx.Sleep(duration) if err != nil { return 0, err } @@ -1832,7 +1877,8 @@ func sleepRecoveryWorkflow(ctx context.Context, duration time.Duration) (time.Du } func TestSleep(t *testing.T) { - setupDBOS(t) + dbosCtx := setupDBOS(t) + RegisterWorkflow(dbosCtx, sleepRecoveryWorkflow) t.Run("SleepDurableRecovery", func(t *testing.T) { sleepStartEvent = NewEvent() @@ -1841,7 +1887,7 @@ func TestSleep(t *testing.T) { // Start a workflow that sleeps for 2 seconds then blocks sleepDuration := 2 * time.Second - handle, err := sleepRecoveryWf(context.Background(), sleepDuration) + handle, err := RunAsWorkflow(dbosCtx, sleepRecoveryWorkflow, sleepDuration) if err != nil { t.Fatalf("failed to start sleep recovery workflow: %v", err) } @@ -1851,7 +1897,7 @@ func TestSleep(t *testing.T) { // Run the workflow again and check the return time was less than the durable sleep startTime := time.Now() - _, err = sleepRecoveryWf(context.Background(), sleepDuration, WithWorkflowID(handle.GetWorkflowID())) + _, err = RunAsWorkflow(dbosCtx, sleepRecoveryWorkflow, sleepDuration, WithWorkflowID(handle.GetWorkflowID())) if err != nil { t.Fatalf("failed to start second sleep recovery workflow: %v", err) } @@ -1864,7 +1910,7 @@ func TestSleep(t *testing.T) { } // Verify the sleep step was recorded correctly - steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbosCtx.(*dbosContext).systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -1883,13 +1929,16 @@ func TestSleep(t *testing.T) { } sleepStopEvent.Set() + + _, err = handle.GetResult() + if err != nil { + t.Fatalf("failed to get sleep workflow result: %v", err) + } }) t.Run("SleepCannotBeCalledOutsideWorkflow", func(t *testing.T) { - ctx := context.Background() - // Attempt to call Sleep outside of a workflow context - _, err := Sleep(ctx, 1*time.Second) + _, err := dbosCtx.Sleep(1 * time.Second) if err == nil { t.Fatal("expected error when calling Sleep outside of workflow context, but got none") }