Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions dbos/admin_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
)
Expand All @@ -27,7 +28,6 @@ const (
_WORKFLOW_FORK_PATTERN = "POST /workflows/{id}/fork"

_ADMIN_SERVER_READ_HEADER_TIMEOUT = 5 * time.Second
_ADMIN_SERVER_SHUTDOWN_TIMEOUT = 10 * time.Second
)

// listWorkflowsRequest represents the request structure for listing workflows
Expand Down Expand Up @@ -103,6 +103,7 @@ type adminServer struct {
logger *slog.Logger
port int
isDeactivated atomic.Int32
wg sync.WaitGroup
}

// toListWorkflowResponse converts a WorkflowStatus to a map with all time fields in UTC
Expand Down Expand Up @@ -226,7 +227,10 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer {
mux.HandleFunc(_DEACTIVATE_PATTERN, func(w http.ResponseWriter, r *http.Request) {
if as.isDeactivated.CompareAndSwap(0, 1) {
ctx.logger.Info("Deactivating DBOS executor", "executor_id", ctx.executorID, "app_version", ctx.applicationVersion)
// TODO: Stop queue runner, workflow scheduler, etc
// Stop the workflow scheduler. Note we don't wait for running jobs to complete
if ctx.workflowScheduler != nil {
ctx.workflowScheduler.Stop()
}
}

w.Header().Set("Content-Type", "text/plain")
Expand Down Expand Up @@ -532,7 +536,9 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer {
func (as *adminServer) Start() error {
as.logger.Info("Starting admin server", "port", as.port)

as.wg.Add(1)
go func() {
defer as.wg.Done()
if err := as.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
as.logger.Error("Admin server error", "error", err)
}
Expand All @@ -541,18 +547,30 @@ func (as *adminServer) Start() error {
return nil
}

func (as *adminServer) Shutdown(ctx context.Context) error {
func (as *adminServer) Shutdown(timeout time.Duration) error {
as.logger.Info("Shutting down admin server")

// Note: consider moving the grace period to DBOSContext.Shutdown()
ctx, cancel := context.WithTimeout(ctx, _ADMIN_SERVER_SHUTDOWN_TIMEOUT)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

if err := as.server.Shutdown(ctx); err != nil {
as.logger.Error("Admin server shutdown error", "error", err)
return fmt.Errorf("failed to shutdown admin server: %w", err)
}

as.logger.Info("Admin server shutdown complete")
// Wait for the server goroutine to return
done := make(chan struct{})
go func() {
as.wg.Wait()
close(done)
}()

select {
case <-done:
as.logger.Info("Admin server shutdown complete")
case <-ctx.Done():
as.logger.Warn("Admin server shutdown timed out")
}

Comment on lines +562 to +574
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out that Shutdown on http.Server is non blocking and graceful, i.e., it doesn't terminate existing connections. Waiting on the goroutine that started it to be done is a clean(er) way to wait for the server's full termination.

return nil
}
77 changes: 72 additions & 5 deletions dbos/admin_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -29,7 +30,7 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

Expand Down Expand Up @@ -65,7 +66,7 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

Expand Down Expand Up @@ -252,7 +253,7 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

Expand Down Expand Up @@ -379,7 +380,7 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

Expand Down Expand Up @@ -530,8 +531,74 @@ func TestAdminServer(t *testing.T) {
}
assert.True(t, foundIDs4[workflowID1], "Expected to find first workflow ID in empty body results")
assert.True(t, foundIDs4[workflowID2], "Expected to find second workflow ID in empty body results")
})

// TestDeactivate checks that calling the admin server's deactivate endpoint
// stops the workflow scheduler: scheduled workflows must run before the call
// and must stop being started after it.
t.Run("TestDeactivate", func(t *testing.T) {
	t.Run("Deactivate stops workflow scheduler", func(t *testing.T) {
		resetTestDatabase(t, databaseURL)
		ctx, err := NewDBOSContext(Config{
			DatabaseURL: databaseURL,
			AppName:     "test-app",
			AdminServer: true,
		})
		require.NoError(t, err)

		// Track scheduled workflow executions
		var executionCount atomic.Int32

		// Register a scheduled workflow that runs every second
		RegisterWorkflow(ctx, func(dbosCtx DBOSContext, scheduledTime time.Time) (string, error) {
			executionCount.Add(1)
			return fmt.Sprintf("executed at %v", scheduledTime), nil
		}, WithSchedule("* * * * * *")) // Every second

		// The context must actually be launched: without it the scheduler never
		// runs, the admin server never listens, and the test asserts nothing.
		err = ctx.Launch()
		require.NoError(t, err)

		client := &http.Client{Timeout: 5 * time.Second}

		// Ensure cleanup
		defer func() {
			if ctx != nil {
				ctx.Shutdown(1 * time.Minute)
			}
			// client.Transport is nil here (the default transport is used), so
			// go through the client, which closes idle connections on whichever
			// transport it actually uses.
			client.CloseIdleConnections()
		}()

		// Wait for 2-3 executions to verify scheduler is running
		require.Eventually(t, func() bool {
			return executionCount.Load() >= 2
		}, 3*time.Second, 100*time.Millisecond, "Expected at least 2 scheduled workflow executions")

		// Call deactivate endpoint
		endpoint := fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_DEACTIVATE_PATTERN, "GET /"))
		req, err := http.NewRequest("GET", endpoint, nil)
		require.NoError(t, err, "Failed to create deactivate request")

		resp, err := client.Do(req)
		require.NoError(t, err, "Failed to call deactivate endpoint")
		defer resp.Body.Close()

		// Verify endpoint returned 200 OK
		assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK from deactivate endpoint")

		// Verify response body
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "Failed to read response body")
		assert.Equal(t, "deactivated", string(body), "Expected 'deactivated' response body")

		// Record count after deactivate and wait
		countAfterDeactivate := executionCount.Load()
		time.Sleep(4 * time.Second) // Wait long enough for multiple executions if scheduler was still running

		// Verify no new executions occurred
		finalCount := executionCount.Load()
		assert.Equal(t, countAfterDeactivate, finalCount,
			"Expected no new scheduled workflows after deactivate (had %d before, %d after)",
			countAfterDeactivate, finalCount)
	})
})

}
Expand Down
116 changes: 70 additions & 46 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ type DBOSContext interface {
context.Context

// Context Lifecycle
Launch() error // Launch the DBOS runtime including system database, queues, admin server, and workflow recovery
Cancel() // Gracefully shutdown the DBOS runtime, waiting for workflows to complete and cleaning up resources
Launch() error // Launch the DBOS runtime including system database, queues, admin server, and workflow recovery
Shutdown(timeout time.Duration) // Gracefully shutdown all DBOS runtime components with ordered cleanup sequence

// Workflow operations
RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) (any, error) // Execute a function as a durable step within a workflow
Expand Down Expand Up @@ -239,7 +239,7 @@ func (c *dbosContext) GetApplicationID() string {
}

// NewDBOSContext creates a new DBOS context with the provided configuration.
// The context must be launched with Launch() before use and should be shut down with Cancel().
// The context must be launched with Launch() before use and should be shut down with Shutdown().
// This function initializes the DBOS system database, sets up the queue sub-system,
// and prepares the workflow registry.
//
Expand All @@ -253,7 +253,7 @@ func (c *dbosContext) GetApplicationID() string {
// if err != nil {
// log.Fatal(err)
// }
// defer ctx.Cancel()
// defer ctx.Shutdown(30*time.Second)
//
// if err := ctx.Launch(); err != nil {
// log.Fatal(err)
Expand Down Expand Up @@ -372,62 +372,86 @@ func (c *dbosContext) Launch() error {
return nil
}

// Cancel gracefully shuts down the DBOS runtime by canceling the context, waiting for
// all workflows to complete, and cleaning up system resources including the database
// connection pool, queue runner, workflow scheduler, and admin server.
// All workflows and steps contexts will be canceled, which one can check using their context's Done() method.
// Shutdown gracefully shuts down the DBOS runtime by performing a complete, ordered cleanup
// of all system components. The shutdown sequence includes:
//
// This method blocks until all workflows finish and all resources are properly cleaned up.
// It should be called when the application is shutting down to ensure data consistency.
func (c *dbosContext) Cancel() {
// 1. Cancels the context and waits for running workflows to finish
// 2. Waits for the queue runner to complete processing
// 3. Stops the workflow scheduler and waits for scheduled jobs to finish
// 4. Shuts down the admin server
// 5. Shuts down the system database connection pool and notification listener
// 6. Marks the context as not launched
//
// The timeout applies to each step separately, so a full shutdown can take several times the
// timeout. If a component doesn't shut down in time, the failure is logged and the shutdown
// continues to the next component.
//
// Shutdown is a permanent operation and should be called when the application is terminating.
func (c *dbosContext) Shutdown(timeout time.Duration) {
c.logger.Info("Shutting down DBOS context")

// Cancel the context to signal all resources to stop
c.ctxCancelFunc(errors.New("DBOS shutdown initiated"))
c.ctxCancelFunc(errors.New("DBOS cancellation initiated"))

// Wait for all workflows to finish
c.logger.Info("Waiting for all workflows to finish")
c.workflowsWg.Wait()
c.logger.Info("All workflows completed")
done := make(chan struct{})
go func() {
c.workflowsWg.Wait()
close(done)
}()
select {
case <-done:
c.logger.Info("All workflows completed")
case <-time.After(timeout):
// For now just log a warning: eventually we might want Shutdown to return an error.
c.logger.Warn("Timeout waiting for workflows to complete", "timeout", timeout)
}

// Close the pool and the notification listener if started
if c.systemDB != nil {
c.logger.Info("Shutting down system database")
c.systemDB.shutdown(c)
c.systemDB = nil
// Wait for queue runner to finish
if c.queueRunner != nil && c.launched.Load() {
c.logger.Info("Waiting for queue runner to complete")
select {
case <-c.queueRunner.completionChan:
c.logger.Info("Queue runner completed")
c.queueRunner = nil
case <-time.After(timeout):
c.logger.Warn("Timeout waiting for queue runner to complete", "timeout", timeout)
Comment on lines +415 to +418
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot safely set the pointer to nil on timeout: the resource might still be accessed afterwards, which would cause a panic.

}
}

if c.launched.Load() {
// Wait for queue runner to finish
<-c.queueRunner.completionChan
c.logger.Info("Queue runner completed")

if c.workflowScheduler != nil {
c.logger.Info("Stopping workflow scheduler")
ctx := c.workflowScheduler.Stop()
// Wait for all running jobs to complete with 5-second timeout
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

select {
case <-ctx.Done():
c.logger.Info("All scheduled jobs completed")
case <-timeoutCtx.Done():
c.logger.Warn("Timeout waiting for jobs to complete. Moving on", "timeout", "5s")
}
// Stop the workflow scheduler and wait until all scheduled workflows are done
if c.workflowScheduler != nil && c.launched.Load() {
c.logger.Info("Stopping workflow scheduler")
ctx := c.workflowScheduler.Stop()

select {
case <-ctx.Done():
c.logger.Info("All scheduled jobs completed")
c.workflowScheduler = nil
case <-time.After(timeout):
c.logger.Warn("Timeout waiting for jobs to complete. Moving on", "timeout", timeout)
}
}

if c.adminServer != nil {
c.logger.Info("Shutting down admin server")
err := c.adminServer.Shutdown(c)
if err != nil {
c.logger.Error("Failed to shutdown admin server", "error", err)
} else {
c.logger.Info("Admin server shutdown complete")
}
c.adminServer = nil
// Shutdown the admin server
if c.adminServer != nil && c.launched.Load() {
c.logger.Info("Shutting down admin server")
err := c.adminServer.Shutdown(timeout)
if err != nil {
c.logger.Error("Failed to shutdown admin server", "error", err)
} else {
c.logger.Info("Admin server shutdown complete")
}
c.adminServer = nil
}

// Close the system database
if c.systemDB != nil {
c.logger.Info("Shutting down system database")
c.systemDB.shutdown(c, timeout)
c.systemDB = nil
}

c.launched.Store(false)
}

Expand Down
3 changes: 2 additions & 1 deletion dbos/dbos_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dbos

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -21,7 +22,7 @@ func TestConfigValidationErrorTypes(t *testing.T) {
require.NoError(t, err)
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1*time.Minute)
}
}() // Clean up executor

Expand Down
Loading
Loading