diff --git a/dbos/admin_server.go b/dbos/admin_server.go
index d25840fc..85ed8205 100644
--- a/dbos/admin_server.go
+++ b/dbos/admin_server.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"log/slog"
 	"net/http"
+	"sync"
 	"sync/atomic"
 	"time"
 )
@@ -27,7 +28,6 @@ const (
 	_WORKFLOW_FORK_PATTERN            = "POST /workflows/{id}/fork"
 
 	_ADMIN_SERVER_READ_HEADER_TIMEOUT = 5 * time.Second
-	_ADMIN_SERVER_SHUTDOWN_TIMEOUT    = 10 * time.Second
 )
 
 // listWorkflowsRequest represents the request structure for listing workflows
@@ -103,6 +103,7 @@ type adminServer struct {
 	logger        *slog.Logger
 	port          int
 	isDeactivated atomic.Int32
+	wg            sync.WaitGroup
 }
 
 // toListWorkflowResponse converts a WorkflowStatus to a map with all time fields in UTC
@@ -226,7 +227,10 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer {
 	mux.HandleFunc(_DEACTIVATE_PATTERN, func(w http.ResponseWriter, r *http.Request) {
 		if as.isDeactivated.CompareAndSwap(0, 1) {
 			ctx.logger.Info("Deactivating DBOS executor", "executor_id", ctx.executorID, "app_version", ctx.applicationVersion)
-			// TODO: Stop queue runner, workflow scheduler, etc
+			// Stop the workflow scheduler. Note we don't wait for running jobs to complete
+			if ctx.workflowScheduler != nil {
+				ctx.workflowScheduler.Stop()
+			}
 		}
 
 		w.Header().Set("Content-Type", "text/plain")
@@ -532,7 +536,9 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer {
 func (as *adminServer) Start() error {
 	as.logger.Info("Starting admin server", "port", as.port)
 
+	as.wg.Add(1)
 	go func() {
+		defer as.wg.Done()
 		if err := as.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
 			as.logger.Error("Admin server error", "error", err)
 		}
@@ -541,11 +547,10 @@ func (as *adminServer) Start() error {
 	return nil
 }
 
-func (as *adminServer) Shutdown(ctx context.Context) error {
+func (as *adminServer) Shutdown(timeout time.Duration) error {
 	as.logger.Info("Shutting down admin server")
 
-	// Note: consider moving the grace period to DBOSContext.Shutdown()
-	ctx, cancel := context.WithTimeout(ctx, _ADMIN_SERVER_SHUTDOWN_TIMEOUT)
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 	defer cancel()
 
 	if err := as.server.Shutdown(ctx); err != nil {
@@ -553,6 +558,19 @@ func (as *adminServer) Shutdown(ctx context.Context) error {
 		return fmt.Errorf("failed to shutdown admin server: %w", err)
 	}
 
-	as.logger.Info("Admin server shutdown complete")
+	// Wait for the server goroutine to return
+	done := make(chan struct{})
+	go func() {
+		as.wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		as.logger.Info("Admin server shutdown complete")
+	case <-ctx.Done():
+		as.logger.Warn("Admin server shutdown timed out")
+	}
+
 	return nil
 }
diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go
index da40359e..25e2ff8c 100644
--- a/dbos/admin_server_test.go
+++ b/dbos/admin_server_test.go
@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -29,7 +30,7 @@ func TestAdminServer(t *testing.T) {
 	// Ensure cleanup
 	defer func() {
 		if ctx != nil {
-			ctx.Cancel()
+			ctx.Shutdown(1 * time.Minute)
 		}
 	}()
 
@@ -65,7 +66,7 @@ func TestAdminServer(t *testing.T) {
 	// Ensure cleanup
 	defer func() {
 		if ctx != nil {
-			ctx.Cancel()
+			ctx.Shutdown(1 * time.Minute)
 		}
 	}()
 
@@ -252,7 +253,7 @@ func TestAdminServer(t *testing.T) {
 	// Ensure cleanup
 	defer func() {
 		if ctx != nil {
-			ctx.Cancel()
+			ctx.Shutdown(1 * time.Minute)
 		}
 	}()
 
@@ -379,7 +380,7 @@ func TestAdminServer(t *testing.T) {
 	// Ensure cleanup
 	defer func() {
 		if ctx != nil {
-			ctx.Cancel()
+			ctx.Shutdown(1 * time.Minute)
 		}
 	}()
 
@@ -530,8 +531,74 @@ func TestAdminServer(t *testing.T) {
 		}
 		assert.True(t, foundIDs4[workflowID1], "Expected to find first workflow ID in empty body results")
 		assert.True(t, foundIDs4[workflowID2], "Expected to find second workflow ID in empty body results")
+	})
+
+	t.Run("TestDeactivate", func(t *testing.T) {
+		t.Run("Deactivate stops workflow scheduler", func(t *testing.T) {
+			resetTestDatabase(t, databaseURL)
+			ctx, err := NewDBOSContext(Config{
+				DatabaseURL: databaseURL,
+				AppName:     "test-app",
+				AdminServer: true,
+			})
+			require.NoError(t, err)
+
+			// Track scheduled workflow executions
+			var executionCount atomic.Int32
+
+			// Register a scheduled workflow that runs every second
+			RegisterWorkflow(ctx, func(dbosCtx DBOSContext, scheduledTime time.Time) (string, error) {
+				executionCount.Add(1)
+				return fmt.Sprintf("executed at %v", scheduledTime), nil
+			}, WithSchedule("* * * * * *")) // Every second
-		return // Skip the normal test flow
+			err = ctx.Launch()
+			require.NoError(t, err)
+
+			client := &http.Client{Timeout: 5 * time.Second}
+
+			// Ensure cleanup
+			defer func() {
+				if ctx != nil {
+					ctx.Shutdown(1 * time.Minute)
+				}
+				if client.Transport != nil {
+					client.Transport.(*http.Transport).CloseIdleConnections()
+				}
+			}()
+
+			// Wait for 2-3 executions to verify scheduler is running
+			require.Eventually(t, func() bool {
+				return executionCount.Load() >= 2
+			}, 3*time.Second, 100*time.Millisecond, "Expected at least 2 scheduled workflow executions")
+
+			// Call deactivate endpoint
+			endpoint := fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_DEACTIVATE_PATTERN, "GET /"))
+			req, err := http.NewRequest("GET", endpoint, nil)
+			require.NoError(t, err, "Failed to create deactivate request")
+
+			resp, err := client.Do(req)
+			require.NoError(t, err, "Failed to call deactivate endpoint")
+			defer resp.Body.Close()
+
+			// Verify endpoint returned 200 OK
+			assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK from deactivate endpoint")
+
+			// Verify response body
+			body, err := io.ReadAll(resp.Body)
+			require.NoError(t, err, "Failed to read response body")
+			assert.Equal(t, "deactivated", string(body), "Expected 'deactivated' response body")
+
+			// Record count after deactivate and wait
+			countAfterDeactivate := executionCount.Load()
+			time.Sleep(4 * time.Second) // Wait long enough for multiple executions if scheduler was still running
+
+			// Verify no new executions occurred
+			finalCount := executionCount.Load()
+			assert.Equal(t, countAfterDeactivate, finalCount,
+				"Expected no new scheduled workflows after deactivate (had %d before, %d after)",
+				countAfterDeactivate, finalCount)
+		})
 	})
 }
diff --git a/dbos/dbos.go b/dbos/dbos.go
index f3c35e08..8b08bfad 100644
--- a/dbos/dbos.go
+++ b/dbos/dbos.go
@@ -71,8 +71,8 @@ type DBOSContext interface {
 	context.Context
 
 	// Context Lifecycle
-	Launch() error // Launch the DBOS runtime including system database, queues, admin server, and workflow recovery
-	Cancel()       // Gracefully shutdown the DBOS runtime, waiting for workflows to complete and cleaning up resources
+	Launch() error                  // Launch the DBOS runtime including system database, queues, admin server, and workflow recovery
+	Shutdown(timeout time.Duration) // Gracefully shutdown all DBOS runtime components with an ordered cleanup sequence
 
 	// Workflow operations
 	RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) // Execute a function as a durable step within a workflow
@@ -239,7 +239,7 @@ func (c *dbosContext) GetApplicationID() string {
 }
 
 // NewDBOSContext creates a new DBOS context with the provided configuration.
-// The context must be launched with Launch() before use and should be shut down with Cancel().
+// The context must be launched with Launch() before use and should be shut down with Shutdown().
 // This function initializes the DBOS system database, sets up the queue sub-system,
 // and prepares the workflow registry.
 //
@@ -253,7 +253,7 @@ func (c *dbosContext) GetApplicationID() string {
//	if err != nil {
//		log.Fatal(err)
//	}
-//	defer ctx.Cancel()
+//	defer ctx.Shutdown(30*time.Second)
//
//	if err := ctx.Launch(); err != nil {
//		log.Fatal(err)
@@ -372,62 +372,86 @@ func (c *dbosContext) Launch() error {
 	return nil
 }
 
-// Cancel gracefully shuts down the DBOS runtime by canceling the context, waiting for
-// all workflows to complete, and cleaning up system resources including the database
-// connection pool, queue runner, workflow scheduler, and admin server.
-// All workflows and steps contexts will be canceled, which one can check using their context's Done() method.
+// Shutdown gracefully shuts down the DBOS runtime by performing a complete, ordered cleanup
+// of all system components. The shutdown sequence includes:
 //
-// This method blocks until all workflows finish and all resources are properly cleaned up.
-// It should be called when the application is shutting down to ensure data consistency.
-func (c *dbosContext) Cancel() {
+// 1. Cancels the context and waits for all running workflows to complete
+// 2. Waits for the queue runner to complete processing
+// 3. Stops the workflow scheduler and waits for scheduled jobs to finish
+// 4. Shuts down the admin server
+// 5. Shuts down the system database connection pool and notification listener
+// 6. Marks the context as not launched
+//
+// Each step respects the provided timeout. If any component doesn't shut down within the timeout,
+// a warning is logged and the shutdown continues to the next component.
+//
+// Shutdown is a permanent operation and should be called when the application is terminating.
+func (c *dbosContext) Shutdown(timeout time.Duration) {
 	c.logger.Info("Shutting down DBOS context")
 
 	// Cancel the context to signal all resources to stop
-	c.ctxCancelFunc(errors.New("DBOS shutdown initiated"))
+	c.ctxCancelFunc(errors.New("DBOS cancellation initiated"))
 
 	// Wait for all workflows to finish
 	c.logger.Info("Waiting for all workflows to finish")
-	c.workflowsWg.Wait()
-	c.logger.Info("All workflows completed")
+	done := make(chan struct{})
+	go func() {
+		c.workflowsWg.Wait()
+		close(done)
+	}()
+	select {
+	case <-done:
+		c.logger.Info("All workflows completed")
+	case <-time.After(timeout):
+		// For now just log a warning: eventually we might want Shutdown to return an error.
+ c.logger.Warn("Timeout waiting for workflows to complete", "timeout", timeout) + } - // Close the pool and the notification listener if started - if c.systemDB != nil { - c.logger.Info("Shutting down system database") - c.systemDB.shutdown(c) - c.systemDB = nil + // Wait for queue runner to finish + if c.queueRunner != nil && c.launched.Load() { + c.logger.Info("Waiting for queue runner to complete") + select { + case <-c.queueRunner.completionChan: + c.logger.Info("Queue runner completed") + c.queueRunner = nil + case <-time.After(timeout): + c.logger.Warn("Timeout waiting for queue runner to complete", "timeout", timeout) + } } - if c.launched.Load() { - // Wait for queue runner to finish - <-c.queueRunner.completionChan - c.logger.Info("Queue runner completed") - - if c.workflowScheduler != nil { - c.logger.Info("Stopping workflow scheduler") - ctx := c.workflowScheduler.Stop() - // Wait for all running jobs to complete with 5-second timeout - timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - select { - case <-ctx.Done(): - c.logger.Info("All scheduled jobs completed") - case <-timeoutCtx.Done(): - c.logger.Warn("Timeout waiting for jobs to complete. Moving on", "timeout", "5s") - } + // Stop the workflow scheduler and wait until all scheduled workflows are done + if c.workflowScheduler != nil && c.launched.Load() { + c.logger.Info("Stopping workflow scheduler") + ctx := c.workflowScheduler.Stop() + + select { + case <-ctx.Done(): + c.logger.Info("All scheduled jobs completed") + c.workflowScheduler = nil + case <-time.After(timeout): + c.logger.Warn("Timeout waiting for jobs to complete. Moving on", "timeout", timeout) } + } - if c.adminServer != nil { - c.logger.Info("Shutting down admin server") - err := c.adminServer.Shutdown(c) - if err != nil { - c.logger.Error("Failed to shutdown admin server", "error", err) - } else { - c.logger.Info("Admin server shutdown complete") - } - c.adminServer = nil + // Shutdown the admin server + if c.adminServer != nil && c.launched.Load() { + c.logger.Info("Shutting down admin server") + err := c.adminServer.Shutdown(timeout) + if err != nil { + c.logger.Error("Failed to shutdown admin server", "error", err) + } else { + c.logger.Info("Admin server shutdown complete") } + c.adminServer = nil } + + // Close the system database + if c.systemDB != nil { + c.logger.Info("Shutting down system database") + c.systemDB.shutdown(c, timeout) + c.systemDB = nil + } + c.launched.Store(false) } diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index f65c314c..98ef59b4 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -2,6 +2,7 @@ package dbos import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -21,7 +22,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { require.NoError(t, err) defer func() { if ctx != nil { - ctx.Cancel() + ctx.Shutdown(1*time.Minute) } }() // Clean up executor diff --git a/dbos/logger_test.go b/dbos/logger_test.go index 4921b294..dd305dbd 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -4,6 +4,7 @@ import ( "bytes" "log/slog" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,7 +23,7 @@ func TestLogger(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { if dbosCtx != nil { - dbosCtx.Cancel() + dbosCtx.Shutdown(10*time.Second) } }) @@ -55,7 +56,7 @@ func TestLogger(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { if dbosCtx != nil { - dbosCtx.Cancel() + 
 		}
 	})
diff --git a/dbos/queue.go b/dbos/queue.go
index 88035ff2..290ba7f3 100644
--- a/dbos/queue.go
+++ b/dbos/queue.go
@@ -158,7 +158,7 @@ func newQueueRunner() *queueRunner {
 		jitterMin:             0.95,
 		jitterMax:             1.05,
 		workflowQueueRegistry: make(map[string]WorkflowQueue),
-		completionChan:        make(chan struct{}),
+		completionChan:        make(chan struct{}, 1),
 	}
 }
diff --git a/dbos/system_database.go b/dbos/system_database.go
index 94496daa..0086d589 100644
--- a/dbos/system_database.go
+++ b/dbos/system_database.go
@@ -26,7 +26,7 @@ import (
 type systemDatabase interface {
 	// SysDB management
 	launch(ctx context.Context)
-	shutdown(ctx context.Context)
+	shutdown(ctx context.Context, timeout time.Duration)
 	resetSystemDB(ctx context.Context) error
 
 	// Workflows
@@ -235,7 +235,7 @@ func (s *sysDB) launch(ctx context.Context) {
 	s.launched = true
 }
 
-func (s *sysDB) shutdown(ctx context.Context) {
+func (s *sysDB) shutdown(ctx context.Context, timeout time.Duration) {
 	s.logger.Info("DBOS: Closing system database connection pool")
 	if s.pool != nil {
 		s.pool.Close()
@@ -251,7 +251,12 @@ func (s *sysDB) shutdown(ctx context.Context) {
 	if s.launched {
 		// Wait for the notification loop to exit
-		<-s.notificationLoopDone
+		s.logger.Info("DBOS: Waiting for notification listener loop to finish")
+		select {
+		case <-s.notificationLoopDone:
+		case <-time.After(timeout):
+			s.logger.Warn("DBOS: Notification listener loop did not finish in time", "timeout", timeout)
+		}
 	}
 
 	s.notificationsMap.Clear()
diff --git a/dbos/utils_test.go b/dbos/utils_test.go
index fc82954c..df7e0412 100644
--- a/dbos/utils_test.go
+++ b/dbos/utils_test.go
@@ -70,7 +70,7 @@ func setupDBOS(t *testing.T, dropDB bool, checkLeaks bool) DBOSContext {
 		fmt.Println("Cleaning up DBOS instance...")
 		dbosCtx.(*dbosContext).logger.Info("Cleaning up DBOS instance...")
 		if dbosCtx != nil {
-			dbosCtx.Cancel()
+			dbosCtx.Shutdown(30 * time.Second) // Wait for workflows to finish and shutdown admin server and system database
 		}
 		dbosCtx = nil
 		if checkLeaks {
diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go
index 19eceac4..f9d6df0f 100644
--- a/dbos/workflows_test.go
+++ b/dbos/workflows_test.go
@@ -351,7 +351,7 @@ func TestWorkflowsRegistration(t *testing.T) {
 		// Launch DBOS context
 		err := freshCtx.Launch()
 		require.NoError(t, err)
-		defer freshCtx.Cancel()
+		defer freshCtx.Shutdown(10 * time.Second)
 
 		// Attempting to register after launch should panic
 		defer func() {
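
Note for reviewers: a minimal usage sketch of the new lifecycle API, assembled from the doc-comment example and the tests in this diff. It is illustrative only; the module import path and the connection string are assumptions, while NewDBOSContext, Config{DatabaseURL, AppName, AdminServer}, Launch, and Shutdown(timeout) come from the changes above.

```go
package main

import (
	"log"
	"time"

	// Assumed import path for the dbos package changed in this diff.
	"github.com/dbos-inc/dbos-transact-golang/dbos"
)

func main() {
	// Construct the DBOS context. Shutdown replaces the old Cancel and takes a
	// timeout that bounds each step of the ordered cleanup sequence.
	ctx, err := dbos.NewDBOSContext(dbos.Config{
		DatabaseURL: "postgres://postgres:dbos@localhost:5432/dbos?sslmode=disable", // assumed placeholder URL
		AppName:     "example-app",
		AdminServer: true,
	})
	if err != nil {
		log.Fatal(err)
	}
	// Deferred so the runtime is torn down when main returns.
	defer ctx.Shutdown(30 * time.Second)

	// Workflows must be registered before Launch (registering afterwards panics,
	// see workflows_test.go above).
	if err := ctx.Launch(); err != nil {
		log.Fatal(err)
	}

	// ... start and await workflows here ...
}
```

Because each component waits on its own time.After(timeout), the 30-second budget above applies per step (workflows, queue runner, scheduler, admin server, notification listener), not to the shutdown as a whole.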
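The same bounded-wait shape recurs throughout the change (workflow WaitGroup, queue runner completion channel, scheduler stop context, notification listener loop): spawn a goroutine that closes a channel when the wait returns, then select against time.After. A self-contained sketch of that pattern, using a hypothetical helper name waitWithTimeout:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// waitWithTimeout waits for wg up to timeout and reports whether it finished in time.
// It mirrors the pattern used in Shutdown: close a channel when Wait returns and
// select against time.After so a stuck component cannot block the whole shutdown.
func waitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		return true
	case <-time.After(timeout):
		return false
	}
}

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(50 * time.Millisecond) // simulated in-flight workflow
	}()
	fmt.Println("completed in time:", waitWithTimeout(&wg, time.Second))
}
```

One trade-off worth noting in review: if the timeout fires first, the goroutine calling Wait stays parked until the WaitGroup eventually drains; that is the price paid for keeping Shutdown bounded instead of blocking indefinitely.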