Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions dbos/admin_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
)
Expand All @@ -27,7 +28,6 @@ const (
_WORKFLOW_FORK_PATTERN = "POST /workflows/{id}/fork"

_ADMIN_SERVER_READ_HEADER_TIMEOUT = 5 * time.Second
_ADMIN_SERVER_SHUTDOWN_TIMEOUT = 10 * time.Second
)

// listWorkflowsRequest represents the request structure for listing workflows
Expand Down Expand Up @@ -103,6 +103,7 @@ type adminServer struct {
logger *slog.Logger
port int
isDeactivated atomic.Int32
wg sync.WaitGroup
}

// toListWorkflowResponse converts a WorkflowStatus to a map with all time fields in UTC
Expand Down Expand Up @@ -226,7 +227,7 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer {
mux.HandleFunc(_DEACTIVATE_PATTERN, func(w http.ResponseWriter, r *http.Request) {
if as.isDeactivated.CompareAndSwap(0, 1) {
ctx.logger.Info("Deactivating DBOS executor", "executor_id", ctx.executorID, "app_version", ctx.applicationVersion)
// TODO: Stop queue runner, workflow scheduler, etc
ctx.Cancel(1 * time.Minute) // Cancel context, queue runner, and workflow scheduler
}

w.Header().Set("Content-Type", "text/plain")
Expand Down Expand Up @@ -532,7 +533,9 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer {
func (as *adminServer) Start() error {
as.logger.Info("Starting admin server", "port", as.port)

as.wg.Add(1)
go func() {
defer as.wg.Done()
if err := as.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
as.logger.Error("Admin server error", "error", err)
}
Expand All @@ -541,18 +544,30 @@ func (as *adminServer) Start() error {
return nil
}

func (as *adminServer) Shutdown(ctx context.Context) error {
func (as *adminServer) Shutdown(timeout time.Duration) error {
as.logger.Info("Shutting down admin server")

// Note: consider moving the grace period to DBOSContext.Shutdown()
ctx, cancel := context.WithTimeout(ctx, _ADMIN_SERVER_SHUTDOWN_TIMEOUT)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

if err := as.server.Shutdown(ctx); err != nil {
as.logger.Error("Admin server shutdown error", "error", err)
return fmt.Errorf("failed to shutdown admin server: %w", err)
}

as.logger.Info("Admin server shutdown complete")
// Wait for the server goroutine to return
done := make(chan struct{})
go func() {
as.wg.Wait()
close(done)
}()

select {
case <-done:
as.logger.Info("Admin server shutdown complete")
case <-ctx.Done():
as.logger.Warn("Admin server shutdown timed out")
}

Comment on lines +562 to +574
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out that Shutdown on http.Server is non blocking and graceful, i.e., it doesn't terminate existing connections. Waiting on the goroutine that started it to be done is a clean(er) way to wait for the server's full termination.

return nil
}
198 changes: 189 additions & 9 deletions dbos/admin_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -29,15 +30,20 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

// Give time for any startup processes
time.Sleep(100 * time.Millisecond)

// Verify admin server is not running
client := &http.Client{Timeout: 1 * time.Second}
client := &http.Client{
Timeout: 1 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}
_, err = client.Get(fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_HEALTHCHECK_PATTERN, "GET /")))
require.Error(t, err, "Expected request to fail when admin server is not started")

Expand Down Expand Up @@ -65,7 +71,7 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

Expand All @@ -78,7 +84,12 @@ func TestAdminServer(t *testing.T) {
exec := ctx.(*dbosContext)
require.NotNil(t, exec.adminServer, "Expected admin server to be created in DBOS instance")

client := &http.Client{Timeout: 5 * time.Second}
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}

type adminServerTestCase struct {
name string
Expand Down Expand Up @@ -252,14 +263,19 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

// Give the server a moment to start
time.Sleep(100 * time.Millisecond)

client := &http.Client{Timeout: 5 * time.Second}
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}
endpoint := fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_WORKFLOWS_PATTERN, "POST /"))

// Create workflows with different input/output types
Expand Down Expand Up @@ -379,11 +395,16 @@ func TestAdminServer(t *testing.T) {
// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Cancel()
ctx.Shutdown(1 * time.Minute)
}
}()

client := &http.Client{Timeout: 5 * time.Second}
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}
endpoint := fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_WORKFLOWS_PATTERN, "POST /"))

// Create first workflow
Expand Down Expand Up @@ -530,8 +551,167 @@ func TestAdminServer(t *testing.T) {
}
assert.True(t, foundIDs4[workflowID1], "Expected to find first workflow ID in empty body results")
assert.True(t, foundIDs4[workflowID2], "Expected to find second workflow ID in empty body results")
})

t.Run("TestDeactivate", func(t *testing.T) {
t.Run("Deactivate stops workflow scheduler", func(t *testing.T) {
resetTestDatabase(t, databaseURL)
ctx, err := NewDBOSContext(Config{
DatabaseURL: databaseURL,
AppName: "test-app",
AdminServer: true,
})
require.NoError(t, err)

// Track scheduled workflow executions
var executionCount atomic.Int32

// Register a scheduled workflow that runs every second
RegisterWorkflow(ctx, func(dbosCtx DBOSContext, scheduledTime time.Time) (string, error) {
executionCount.Add(1)
return fmt.Sprintf("executed at %v", scheduledTime), nil
}, WithSchedule("* * * * * *")) // Every second

err = ctx.Launch()
require.NoError(t, err)

client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}

// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Shutdown(1 * time.Minute)
}
if client.Transport != nil {
client.Transport.(*http.Transport).CloseIdleConnections()
}
}()

// Wait for 2-3 executions to verify scheduler is running
require.Eventually(t, func() bool {
return executionCount.Load() >= 2
}, 3*time.Second, 100*time.Millisecond, "Expected at least 2 scheduled workflow executions")

// Call deactivate endpoint
endpoint := fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_DEACTIVATE_PATTERN, "GET /"))
req, err := http.NewRequest("GET", endpoint, nil)
require.NoError(t, err, "Failed to create deactivate request")

resp, err := client.Do(req)
require.NoError(t, err, "Failed to call deactivate endpoint")
defer resp.Body.Close()

// Verify endpoint returned 200 OK
assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK from deactivate endpoint")

// Verify response body
body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "Failed to read response body")
assert.Equal(t, "deactivated", string(body), "Expected 'deactivated' response body")

// Record count after deactivate and wait
countAfterDeactivate := executionCount.Load()
time.Sleep(4 * time.Second) // Wait long enough for multiple executions if scheduler was still running

// Verify no new executions occurred
finalCount := executionCount.Load()
assert.Equal(t, countAfterDeactivate, finalCount,
"Expected no new scheduled workflows after deactivate (had %d before, %d after)",
countAfterDeactivate, finalCount)
})

t.Run("Deactivate stops queue runner", func(t *testing.T) {
resetTestDatabase(t, databaseURL)
ctx, err := NewDBOSContext(Config{
DatabaseURL: databaseURL,
AppName: "test-app",
AdminServer: true,
})
require.NoError(t, err)

return // Skip the normal test flow
// Create a test queue
testQueue := NewWorkflowQueue(ctx, "test-deactivate-queue")

// Track workflow executions with atomic counter
var executionCount atomic.Int32

// Register a simple workflow
testWorkflow := func(dbosCtx DBOSContext, input string) (string, error) {
executionCount.Add(1)
return "completed-" + input, nil
}
RegisterWorkflow(ctx, testWorkflow)

err = ctx.Launch()
require.NoError(t, err)

client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}

// Ensure cleanup
defer func() {
if ctx != nil {
ctx.Shutdown(1 * time.Minute)
}
if client.Transport != nil {
client.Transport.(*http.Transport).CloseIdleConnections()
}
}()

// Enqueue and complete one workflow to verify queue runner is working
handle1, err := RunAsWorkflow(ctx, testWorkflow, "initial-test", WithQueue(testQueue.Name))
require.NoError(t, err, "Failed to enqueue initial workflow")

result1, err := handle1.GetResult()
require.NoError(t, err, "Failed to get initial workflow result")
assert.Equal(t, "completed-initial-test", result1)
assert.Equal(t, int32(1), executionCount.Load(), "Expected one execution before deactivate")

// Call deactivate endpoint to stop the queue runner
endpoint := fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_DEACTIVATE_PATTERN, "GET /"))
req, err := http.NewRequest("GET", endpoint, nil)
require.NoError(t, err, "Failed to create deactivate request")

resp, err := client.Do(req)
require.NoError(t, err, "Failed to call deactivate endpoint")
defer resp.Body.Close()

// Verify endpoint returned 200 OK
assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected 200 OK from deactivate endpoint")

// Verify response body
body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "Failed to read response body")
assert.Equal(t, "deactivated", string(body), "Expected 'deactivated' response body")

// After deactivate is called, the context is cancelled
// We can enqueue more workflows but they won't be dequeued
handle2, err := RunAsWorkflow(ctx, testWorkflow, "post-deactivate-test", WithQueue(testQueue.Name))
require.NoError(t, err, "Failed to enqueue post-deactivate workflow")

countAfterDeactivate := executionCount.Load()

// Wait to see if any phantom executions happen
time.Sleep(2 * time.Second)

// Verify no additional workflows executed (should still be 1)
finalCount := executionCount.Load()
assert.Equal(t, countAfterDeactivate, finalCount,
"Expected no additional workflow executions after deactivate")

handle2Status, err := handle2.GetStatus()
require.NoError(t, err, "Failed to get status of post-deactivate workflow")
assert.Equal(t, WorkflowStatusEnqueued, handle2Status.Status, "Expected post-deactivate workflow to be pending")
})
})

}
Expand Down
Loading
Loading