Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions dbos/admin-server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package dbos

import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)

type AdminServer struct {
server *http.Server
}

func NewAdminServer(port int) *AdminServer {
mux := http.NewServeMux()

// Health endpoint
mux.HandleFunc("/dbos-healthz", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"healthy"}`))
})

// Recovery endpoint
mux.HandleFunc("/dbos-workflow-recovery", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

var executorIDs []string
if err := json.NewDecoder(r.Body).Decode(&executorIDs); err != nil {
http.Error(w, "Invalid JSON body", http.StatusBadRequest)
return
}

getLogger().Info("Recovering workflows for executors", "executors", executorIDs)

handles, err := recoverPendingWorkflows(r.Context(), executorIDs)
if err != nil {
getLogger().Error("Error recovering workflows", "error", err)
http.Error(w, fmt.Sprintf("Recovery failed: %v", err), http.StatusInternalServerError)
return
}

// Extract workflow IDs from handles
workflowIDs := make([]string, len(handles))
for i, handle := range handles {
workflowIDs[i] = handle.GetWorkflowID()
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(workflowIDs); err != nil {
getLogger().Error("Error encoding response", "error", err)
}
})

server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
}

return &AdminServer{
server: server,
}
}

func (as *AdminServer) Start() error {
getLogger().Info("Starting admin server", "port", 3001)

go func() {
if err := as.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
getLogger().Error("Admin server error", "error", err)
}
}()

return nil
}

func (as *AdminServer) Shutdown() error {
getLogger().Info("Shutting down admin server")

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

if err := as.server.Shutdown(ctx); err != nil {
getLogger().Error("Admin server shutdown error", "error", err)
return fmt.Errorf("failed to shutdown admin server: %w", err)
}

getLogger().Info("Admin server shutdown complete")
return nil
}
176 changes: 176 additions & 0 deletions dbos/admin-server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package dbos

import (
"bytes"
"encoding/json"
"io"
"net/http"
"os"
"strings"
"testing"
"time"
)

func TestAdminServer(t *testing.T) {
// Skip if database is not available
databaseURL := os.Getenv("DBOS_DATABASE_URL")
if databaseURL == "" && os.Getenv("PGPASSWORD") == "" {
t.Skip("Database not available (DBOS_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
}

t.Run("Admin server is not started without WithAdminServer option", func(t *testing.T) {
// Ensure clean state
if dbos != nil {
Shutdown()
}

// Launch DBOS without admin server option
err := Launch()
if err != nil {
t.Skipf("Failed to launch DBOS (database likely not available): %v", err)
}

// Ensure cleanup
defer Shutdown()

// Give time for any startup processes
time.Sleep(100 * time.Millisecond)

// Verify admin server is not running
client := &http.Client{Timeout: 1 * time.Second}
_, err = client.Get("http://localhost:3001/dbos-healthz")
if err == nil {
t.Error("Expected request to fail when admin server is not started, but it succeeded")
}

// Verify the DBOS executor doesn't have an admin server instance
if dbos == nil {
t.Fatal("Expected DBOS instance to be created")
}

if dbos.adminServer != nil {
t.Error("Expected admin server to be nil when not configured")
}
})

t.Run("Admin server endpoints", func(t *testing.T) {
// Ensure clean state
if dbos != nil {
Shutdown()
}

// Launch DBOS with admin server once for all endpoint tests
err := Launch(WithAdminServer())
if err != nil {
t.Skipf("Failed to launch DBOS with admin server (database likely not available): %v", err)
}

// Ensure cleanup
defer Shutdown()

// Give the server a moment to start
time.Sleep(100 * time.Millisecond)

// Verify the DBOS executor has an admin server instance
if dbos == nil {
t.Fatal("Expected DBOS instance to be created")
}

if dbos.adminServer == nil {
t.Fatal("Expected admin server to be created in DBOS instance")
}

client := &http.Client{Timeout: 5 * time.Second}

tests := []struct {
name string
method string
endpoint string
body io.Reader
contentType string
expectedStatus int
validateResp func(t *testing.T, resp *http.Response)
}{
{
name: "Health endpoint responds correctly",
method: "GET",
endpoint: "http://localhost:3001/dbos-healthz",
expectedStatus: http.StatusOK,
},
{
name: "Recovery endpoint responds correctly with valid JSON",
method: "POST",
endpoint: "http://localhost:3001/dbos-workflow-recovery",
body: bytes.NewBuffer(mustMarshal([]string{"executor1", "executor2"})),
contentType: "application/json",
expectedStatus: http.StatusOK,
validateResp: func(t *testing.T, resp *http.Response) {
var workflowIDs []string
if err := json.NewDecoder(resp.Body).Decode(&workflowIDs); err != nil {
t.Errorf("Failed to decode response as JSON array: %v", err)
}
if workflowIDs == nil {
t.Error("Expected non-nil workflow IDs array")
}
},
},
{
name: "Recovery endpoint rejects invalid methods",
method: "GET",
endpoint: "http://localhost:3001/dbos-workflow-recovery",
expectedStatus: http.StatusMethodNotAllowed,
},
{
name: "Recovery endpoint rejects invalid JSON",
method: "POST",
endpoint: "http://localhost:3001/dbos-workflow-recovery",
body: strings.NewReader(`{"invalid": json}`),
contentType: "application/json",
expectedStatus: http.StatusBadRequest,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req *http.Request
var err error

if tt.body != nil {
req, err = http.NewRequest(tt.method, tt.endpoint, tt.body)
} else {
req, err = http.NewRequest(tt.method, tt.endpoint, nil)
}
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}

resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tt.expectedStatus {
body, _ := io.ReadAll(resp.Body)
t.Errorf("Expected status code %d, got %d. Response: %s", tt.expectedStatus, resp.StatusCode, string(body))
}

if tt.validateResp != nil {
tt.validateResp(t, resp)
}
})
}
})
}

func mustMarshal(v any) []byte {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
40 changes: 36 additions & 4 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (
)

var (
APP_VERSION string
EXECUTOR_ID string
APP_ID string
APP_VERSION string
EXECUTOR_ID string
APP_ID string
DEFAULT_ADMIN_SERVER_PORT = 3001
)

func computeApplicationVersion() string {
Expand Down Expand Up @@ -62,6 +63,7 @@ type executor struct {
queueRunnerCtx context.Context
queueRunnerCancelFunc context.CancelFunc
queueRunnerDone chan struct{}
adminServer *AdminServer
}

var dbos *executor
Expand All @@ -88,7 +90,8 @@ func getLogger() *slog.Logger {
}

type config struct {
logger *slog.Logger
logger *slog.Logger
adminServer bool
}

type LaunchOption func(*config)
Expand All @@ -99,6 +102,12 @@ func WithLogger(logger *slog.Logger) LaunchOption {
}
}

func WithAdminServer() LaunchOption {
return func(config *config) {
config.adminServer = true
}
}

func Launch(options ...LaunchOption) error {
if dbos != nil {
fmt.Println("warning: DBOS instance already initialized, skipping re-initialization")
Expand Down Expand Up @@ -138,6 +147,18 @@ func Launch(options ...LaunchOption) error {

systemDB.Launch(context.Background())

// Start the admin server if configured
var adminServer *AdminServer
if config.adminServer {
adminServer = NewAdminServer(DEFAULT_ADMIN_SERVER_PORT)
err := adminServer.Start()
if err != nil {
logger.Error("Failed to start admin server", "error", err)
return NewInitializationError(fmt.Sprintf("failed to start admin server: %v", err))
}
logger.Info("Admin server started", "port", DEFAULT_ADMIN_SERVER_PORT)
}

// Create context with cancel function for queue runner
ctx, cancel := context.WithCancel(context.Background())

Expand All @@ -146,6 +167,7 @@ func Launch(options ...LaunchOption) error {
queueRunnerCtx: ctx,
queueRunnerCancelFunc: cancel,
queueRunnerDone: make(chan struct{}),
adminServer: adminServer,
}

// Start the queue runner in a goroutine
Expand Down Expand Up @@ -207,6 +229,16 @@ func Shutdown() {
dbos.systemDB = nil
}

if dbos.adminServer != nil {
err := dbos.adminServer.Shutdown()
if err != nil {
getLogger().Error("Failed to shutdown admin server", "error", err)
} else {
getLogger().Info("Admin server shutdown complete")
}
dbos.adminServer = nil
}

if logger != nil {
logger = nil
}
Expand Down
Loading