From 4da89a8770b346b172fcbf3cdcf0cc6d49b5eb15 Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 22 Jul 2025 13:03:05 -0700 Subject: [PATCH 1/4] update scopes + DBOSExecutor singleton --- dbos/admin-server.go | 36 +++--- dbos/admin-server_test.go | 44 ++++--- dbos/dbos.go | 169 +++++++++++++------------ dbos/errors.go | 30 ++--- dbos/logger_test.go | 20 ++- dbos/queue.go | 56 ++++----- dbos/queues_test.go | 66 +++++----- dbos/recovery.go | 12 +- dbos/serialization_test.go | 12 +- dbos/system_database.go | 244 ++++++++++++++++++------------------- dbos/utils_test.go | 13 +- dbos/workflow.go | 222 ++++++++++++++++----------------- dbos/workflows_test.go | 40 +++--- 13 files changed, 506 insertions(+), 458 deletions(-) diff --git a/dbos/admin-server.go b/dbos/admin-server.go index 7c9e0dbc..35f18002 100644 --- a/dbos/admin-server.go +++ b/dbos/admin-server.go @@ -9,34 +9,34 @@ import ( ) const ( - HealthCheckPath = "/dbos-healthz" - WorkflowRecoveryPath = "/dbos-workflow-recovery" - WorkflowQueuesMetadataPath = "/dbos-workflow-queues-metadata" + healthCheckPath = "/dbos-healthz" + workflowRecoveryPath = "/dbos-workflow-recovery" + workflowQueuesMetadataPath = "/dbos-workflow-queues-metadata" ) -type AdminServer struct { +type adminServer struct { server *http.Server } -type QueueMetadata struct { +type queueMetadata struct { Name string `json:"name"` Concurrency *int `json:"concurrency,omitempty"` WorkerConcurrency *int `json:"workerConcurrency,omitempty"` RateLimit *RateLimiter `json:"rateLimit,omitempty"` } -func NewAdminServer(port int) *AdminServer { +func newAdminServer(port int) *adminServer { mux := http.NewServeMux() // Health endpoint - mux.HandleFunc(HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(healthCheckPath, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"status":"healthy"}`)) }) // Recovery endpoint - mux.HandleFunc(WorkflowRecoveryPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(workflowRecoveryPath, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return @@ -71,21 +71,21 @@ func NewAdminServer(port int) *AdminServer { }) // Queue metadata endpoint - mux.HandleFunc(WorkflowQueuesMetadataPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(workflowQueuesMetadataPath, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } - var queueMetadataArray []QueueMetadata + var queueMetadataArray []queueMetadata // Iterate through all queues in the registry for _, queue := range workflowQueueRegistry { - queueMetadata := QueueMetadata{ - Name: queue.Name, - WorkerConcurrency: queue.WorkerConcurrency, - Concurrency: queue.GlobalConcurrency, - RateLimit: queue.Limiter, + queueMetadata := queueMetadata{ + Name: queue.name, + WorkerConcurrency: queue.workerConcurrency, + Concurrency: queue.globalConcurrency, + RateLimit: queue.limiter, } queueMetadataArray = append(queueMetadataArray, queueMetadata) @@ -103,12 +103,12 @@ func NewAdminServer(port int) *AdminServer { Handler: mux, } - return &AdminServer{ + return &adminServer{ server: server, } } -func (as *AdminServer) Start() error { +func (as *adminServer) Start() error { getLogger().Info("Starting admin server", "port", 3001) go func() { @@ -120,7 +120,7 @@ func (as *AdminServer) Start() error { return nil } -func (as *AdminServer) Shutdown() error { +func (as *adminServer) Shutdown() error { getLogger().Info("Shutting down admin server") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) diff --git a/dbos/admin-server_test.go b/dbos/admin-server_test.go index 5b8bb724..b287aee8 100644 --- a/dbos/admin-server_test.go +++ b/dbos/admin-server_test.go @@ -21,24 +21,32 @@ func TestAdminServer(t *testing.T) { t.Run("Admin server is not started without WithAdminServer option", func(t *testing.T) { // Ensure clean state if dbos != nil { - Shutdown() + dbos.Shutdown() } // Launch DBOS without admin server option - err := Launch() + executor, err := NewExecutor() + if err != nil { + t.Skipf("Failed to create DBOS (database likely not available): %v", err) + } + err = executor.Launch() if err != nil { t.Skipf("Failed to launch DBOS (database likely not available): %v", err) } // Ensure cleanup - defer Shutdown() + defer func() { + if executor != nil { + executor.Shutdown() + } + }() // Give time for any startup processes time.Sleep(100 * time.Millisecond) // Verify admin server is not running client := &http.Client{Timeout: 1 * time.Second} - _, err = client.Get("http://localhost:3001" + HealthCheckPath) + _, err = client.Get("http://localhost:3001" + healthCheckPath) if err == nil { t.Error("Expected request to fail when admin server is not started, but it succeeded") } @@ -56,17 +64,25 @@ func TestAdminServer(t *testing.T) { t.Run("Admin server endpoints", func(t *testing.T) { // Ensure clean state if dbos != nil { - Shutdown() + dbos.Shutdown() } // Launch DBOS with admin server once for all endpoint tests - err := Launch(WithAdminServer()) + executor, err := NewExecutor(WithAdminServer()) + if err != nil { + t.Skipf("Failed to create DBOS with admin server (database likely not available): %v", err) + } + err = executor.Launch() if err != nil { t.Skipf("Failed to launch DBOS with admin server (database likely not available): %v", err) } // Ensure cleanup - defer Shutdown() + defer func() { + if executor != nil { + executor.Shutdown() + } + }() // Give the server a moment to start time.Sleep(100 * time.Millisecond) @@ -94,13 +110,13 @@ func TestAdminServer(t *testing.T) { { name: "Health endpoint responds correctly", method: "GET", - endpoint: "http://localhost:3001" + HealthCheckPath, + endpoint: "http://localhost:3001" + healthCheckPath, expectedStatus: http.StatusOK, }, { name: "Recovery endpoint responds correctly with valid JSON", method: "POST", - endpoint: "http://localhost:3001" + WorkflowRecoveryPath, + endpoint: "http://localhost:3001" + workflowRecoveryPath, body: bytes.NewBuffer(mustMarshal([]string{"executor1", "executor2"})), contentType: "application/json", expectedStatus: http.StatusOK, @@ -117,13 +133,13 @@ func TestAdminServer(t *testing.T) { { name: "Recovery endpoint rejects invalid methods", method: "GET", - endpoint: "http://localhost:3001" + WorkflowRecoveryPath, + endpoint: "http://localhost:3001" + workflowRecoveryPath, expectedStatus: http.StatusMethodNotAllowed, }, { name: "Recovery endpoint rejects invalid JSON", method: "POST", - endpoint: "http://localhost:3001" + WorkflowRecoveryPath, + endpoint: "http://localhost:3001" + workflowRecoveryPath, body: strings.NewReader(`{"invalid": json}`), contentType: "application/json", expectedStatus: http.StatusBadRequest, @@ -131,10 +147,10 @@ func TestAdminServer(t *testing.T) { { name: "Queue metadata endpoint responds correctly", method: "GET", - endpoint: "http://localhost:3001" + WorkflowQueuesMetadataPath, + endpoint: "http://localhost:3001" + workflowQueuesMetadataPath, expectedStatus: http.StatusOK, validateResp: func(t *testing.T, resp *http.Response) { - var queueMetadata []QueueMetadata + var queueMetadata []queueMetadata if err := json.NewDecoder(resp.Body).Decode(&queueMetadata); err != nil { t.Errorf("Failed to decode response as QueueMetadata array: %v", err) } @@ -170,7 +186,7 @@ func TestAdminServer(t *testing.T) { { name: "Queue metadata endpoint rejects invalid methods", method: "POST", - endpoint: "http://localhost:3001" + WorkflowQueuesMetadataPath, + endpoint: "http://localhost:3001" + workflowQueuesMetadataPath, expectedStatus: http.StatusMethodNotAllowed, }, } diff --git a/dbos/dbos.go b/dbos/dbos.go index a09d9a79..04ac7344 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -17,10 +17,10 @@ import ( ) var ( - APP_VERSION string - EXECUTOR_ID string - APP_ID string - DEFAULT_ADMIN_SERVER_PORT = 3001 + _APP_VERSION string + _EXECUTOR_ID string + _APP_ID string + _DEFAULT_ADMIN_SERVER_PORT = 3001 ) func computeApplicationVersion() string { @@ -53,30 +53,9 @@ func computeApplicationVersion() string { } -type Executor interface { - Shutdown() -} - -var workflowScheduler *cron.Cron - -type executor struct { - systemDB SystemDatabase - queueRunnerCtx context.Context - queueRunnerCancelFunc context.CancelFunc - queueRunnerDone chan struct{} - adminServer *AdminServer -} - -var dbos *executor +var workflowScheduler *cron.Cron // Global because accessed during workflow registration before the dbos singleton is initialized -func getExecutor() *executor { - if dbos == nil { - return nil - } - return dbos -} - -var logger *slog.Logger +var logger *slog.Logger // Global because accessed everywhere inside the library func getLogger() *slog.Logger { if dbos == nil || logger == nil { @@ -124,93 +103,124 @@ func NewConfig(programmaticConfig config) *config { return dbosConfig } -type LaunchOption func(*config) +var dbos *Executor // DBOS singleton instance + +type Executor struct { + systemDB SystemDatabase + queueRunnerCtx context.Context + queueRunnerCancelFunc context.CancelFunc + queueRunnerDone chan struct{} + adminServer *adminServer + config *config +} + +type executorOption func(*config) -func WithLogger(logger *slog.Logger) LaunchOption { +func WithLogger(logger *slog.Logger) executorOption { return func(config *config) { config.logger = logger } } -func WithAdminServer() LaunchOption { +func WithAdminServer() executorOption { return func(config *config) { config.adminServer = true } } -func WithDatabaseURL(url string) LaunchOption { +func WithDatabaseURL(url string) executorOption { return func(config *config) { config.databaseURL = url } } -func Launch(options ...LaunchOption) error { +func WithAppName(name string) executorOption { + return func(config *config) { + config.appName = name + } +} + +func NewExecutor(options ...executorOption) (*Executor, error) { if dbos != nil { fmt.Println("warning: DBOS instance already initialized, skipping re-initialization") - return NewInitializationError("DBOS already initialized") + return nil, newInitializationError("DBOS already initialized") } - // Load & process the configuration + // Start with default configuration config := &config{ logger: slog.New(slog.NewTextHandler(os.Stderr, nil)), } + + // Apply options for _, option := range options { option(config) } + + // Load & process the configuration config = NewConfig(*config) + // Set global logger logger = config.logger - // Initialize with environment variables, providing defaults if not set - APP_VERSION = os.Getenv("DBOS__APPVERSION") - if APP_VERSION == "" { - APP_VERSION = computeApplicationVersion() + // Initialize global variables with environment variables, providing defaults if not set + _APP_VERSION = os.Getenv("DBOS__APPVERSION") + if _APP_VERSION == "" { + _APP_VERSION = computeApplicationVersion() logger.Info("DBOS__APPVERSION not set, using computed hash") } - EXECUTOR_ID = os.Getenv("DBOS__VMID") - if EXECUTOR_ID == "" { - EXECUTOR_ID = "local" - logger.Info("DBOS__VMID not set, using default", "executor_id", EXECUTOR_ID) + _EXECUTOR_ID = os.Getenv("DBOS__VMID") + if _EXECUTOR_ID == "" { + _EXECUTOR_ID = "local" + logger.Info("DBOS__VMID not set, using default", "executor_id", _EXECUTOR_ID) } - APP_ID = os.Getenv("DBOS__APPID") + _APP_ID = os.Getenv("DBOS__APPID") // Create the system database systemDB, err := NewSystemDatabase(config.databaseURL) if err != nil { - return NewInitializationError(fmt.Sprintf("failed to create system database: %v", err)) + return nil, newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) } logger.Info("System database initialized") - systemDB.Launch(context.Background()) + // Create the executor instance + executor := &Executor{ + config: config, + systemDB: systemDB, + } + + // Set the global dbos instance + dbos = executor + + return executor, nil +} + +func (e *Executor) Launch() error { + // Start the system database + e.systemDB.Launch(context.Background()) // Start the admin server if configured - var adminServer *AdminServer - if config.adminServer { - adminServer = NewAdminServer(DEFAULT_ADMIN_SERVER_PORT) + if e.config.adminServer { + adminServer := newAdminServer(_DEFAULT_ADMIN_SERVER_PORT) err := adminServer.Start() if err != nil { logger.Error("Failed to start admin server", "error", err) - return NewInitializationError(fmt.Sprintf("failed to start admin server: %v", err)) + return newInitializationError(fmt.Sprintf("failed to start admin server: %v", err)) } - logger.Info("Admin server started", "port", DEFAULT_ADMIN_SERVER_PORT) + logger.Info("Admin server started", "port", _DEFAULT_ADMIN_SERVER_PORT) + e.adminServer = adminServer } // Create context with cancel function for queue runner ctx, cancel := context.WithCancel(context.Background()) - - dbos = &executor{ - systemDB: systemDB, - queueRunnerCtx: ctx, - queueRunnerCancelFunc: cancel, - queueRunnerDone: make(chan struct{}), - adminServer: adminServer, - } + e.queueRunnerCtx = ctx + e.queueRunnerCancelFunc = cancel + e.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(dbos.queueRunnerDone) + defer close(e.queueRunnerDone) queueRunner(ctx) }() logger.Info("Queue runner started") @@ -222,29 +232,28 @@ func Launch(options ...LaunchOption) error { } // Run a round of recovery on the local executor - _, err = recoverPendingWorkflows(context.Background(), []string{EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it? + _, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it? if err != nil { - return NewInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err)) + return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err)) } - logger.Info("DBOS initialized", "app_version", APP_VERSION, "executor_id", EXECUTOR_ID) + logger.Info("DBOS initialized", "app_version", _APP_VERSION, "executor_id", _EXECUTOR_ID) return nil } -// Close closes the DBOS instance and its resources -func Shutdown() { - if dbos == nil { - fmt.Println("DBOS instance is nil, cannot destroy") +func (e *Executor) Shutdown() { + if e == nil { + fmt.Println("Executor instance is nil, cannot shutdown") return } // XXX is there a way to ensure all workflows goroutine are done before closing? // Cancel the context to stop the queue runner - if dbos.queueRunnerCancelFunc != nil { - dbos.queueRunnerCancelFunc() + if e.queueRunnerCancelFunc != nil { + e.queueRunnerCancelFunc() // Wait for queue runner to finish - <-dbos.queueRunnerDone + <-e.queueRunnerDone getLogger().Info("Queue runner stopped") } @@ -262,24 +271,26 @@ func Shutdown() { } } - if dbos.systemDB != nil { - dbos.systemDB.Shutdown() - dbos.systemDB = nil + if e.systemDB != nil { + e.systemDB.Shutdown() + e.systemDB = nil } - if dbos.adminServer != nil { - err := dbos.adminServer.Shutdown() + if e.adminServer != nil { + err := e.adminServer.Shutdown() if err != nil { getLogger().Error("Failed to shutdown admin server", "error", err) } else { getLogger().Info("Admin server shutdown complete") } - dbos.adminServer = nil + e.adminServer = nil } - if logger != nil { - logger = nil + // Clear global references if this is the global instance + if dbos == e { + if logger != nil { + logger = nil + } + dbos = nil } - - dbos = nil } diff --git a/dbos/errors.go b/dbos/errors.go index 33d8ef63..b6131a16 100644 --- a/dbos/errors.go +++ b/dbos/errors.go @@ -45,7 +45,7 @@ func (e *DBOSError) Error() string { return fmt.Sprintf("DBOS Error %d: %s", int(e.Code), e.Message) } -func NewConflictingWorkflowError(workflowID, message string) *DBOSError { +func newConflictingWorkflowError(workflowID, message string) *DBOSError { msg := fmt.Sprintf("Conflicting workflow invocation with the same ID (%s)", workflowID) if message != "" { msg += ": " + message @@ -57,14 +57,14 @@ func NewConflictingWorkflowError(workflowID, message string) *DBOSError { } } -func NewInitializationError(message string) *DBOSError { +func newInitializationError(message string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Error initializing DBOS Transact: %s", message), Code: InitializationError, } } -func NewWorkflowFunctionNotFoundError(workflowID, message string) *DBOSError { +func newWorkflowFunctionNotFoundError(workflowID, message string) *DBOSError { msg := fmt.Sprintf("Workflow function not found for workflow ID %s", workflowID) if message != "" { msg += ": " + message @@ -76,7 +76,7 @@ func NewWorkflowFunctionNotFoundError(workflowID, message string) *DBOSError { } } -func NewNonExistentWorkflowError(workflowID string) *DBOSError { +func newNonExistentWorkflowError(workflowID string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("workflow %s does not exist", workflowID), Code: NonExistentWorkflowError, @@ -84,14 +84,14 @@ func NewNonExistentWorkflowError(workflowID string) *DBOSError { } } -func NewConflictingRegistrationError(name string) *DBOSError { +func newConflictingRegistrationError(name string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("%s is already registered", name), Code: ConflictingRegistrationError, } } -func NewUnexpectedStepError(workflowID string, stepID int, expectedName, recordedName string) *DBOSError { +func newUnexpectedStepError(workflowID string, stepID int, expectedName, recordedName string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("During execution of workflow %s step %d, function %s was recorded when %s was expected. Check that your workflow is deterministic.", workflowID, stepID, recordedName, expectedName), Code: UnexpectedStep, @@ -102,7 +102,7 @@ func NewUnexpectedStepError(workflowID string, stepID int, expectedName, recorde } } -func NewAwaitedWorkflowCancelledError(workflowID string) *DBOSError { +func newAwaitedWorkflowCancelledError(workflowID string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Awaited workflow %s was cancelled", workflowID), Code: AwaitedWorkflowCancelled, @@ -110,7 +110,7 @@ func NewAwaitedWorkflowCancelledError(workflowID string) *DBOSError { } } -func NewWorkflowCancelledError(workflowID string) *DBOSError { +func newWorkflowCancelledError(workflowID string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Workflow %s was cancelled", workflowID), Code: WorkflowCancelled, @@ -118,7 +118,7 @@ func NewWorkflowCancelledError(workflowID string) *DBOSError { } } -func NewWorkflowConflictIDError(workflowID string) *DBOSError { +func newWorkflowConflictIDError(workflowID string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Conflicting workflow ID %s", workflowID), Code: ConflictingIDError, @@ -127,7 +127,7 @@ func NewWorkflowConflictIDError(workflowID string) *DBOSError { } } -func NewWorkflowUnexpectedResultType(workflowID, expectedType, actualType string) *DBOSError { +func newWorkflowUnexpectedResultType(workflowID, expectedType, actualType string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Workflow %s returned unexpected result type: expected %s, got %s", workflowID, expectedType, actualType), Code: WorkflowUnexpectedTypeError, @@ -136,7 +136,7 @@ func NewWorkflowUnexpectedResultType(workflowID, expectedType, actualType string } } -func NewWorkflowUnexpectedInputType(workflowName, expectedType, actualType string) *DBOSError { +func newWorkflowUnexpectedInputType(workflowName, expectedType, actualType string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Workflow %s received unexpected input type: expected %s, got %s", workflowName, expectedType, actualType), Code: WorkflowUnexpectedTypeError, @@ -144,7 +144,7 @@ func NewWorkflowUnexpectedInputType(workflowName, expectedType, actualType strin } } -func NewWorkflowExecutionError(workflowID, message string) *DBOSError { +func newWorkflowExecutionError(workflowID, message string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Workflow %s execution error: %s", workflowID, message), Code: WorkflowExecutionError, @@ -153,7 +153,7 @@ func NewWorkflowExecutionError(workflowID, message string) *DBOSError { } } -func NewStepExecutionError(workflowID, stepName, message string) *DBOSError { +func newStepExecutionError(workflowID, stepName, message string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Step %s in workflow %s execution error: %s", stepName, workflowID, message), Code: StepExecutionError, @@ -163,7 +163,7 @@ func NewStepExecutionError(workflowID, stepName, message string) *DBOSError { } } -func NewDeadLetterQueueError(workflowID string, maxRetries int) *DBOSError { +func newDeadLetterQueueError(workflowID string, maxRetries int) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Workflow %s has been moved to the dead-letter queue after exceeding the maximum of %d retries", workflowID, maxRetries), Code: DeadLetterQueueError, @@ -173,7 +173,7 @@ func NewDeadLetterQueueError(workflowID string, maxRetries int) *DBOSError { } } -func NewMaxStepRetriesExceededError(workflowID, stepName string, maxRetries int, err error) *DBOSError { +func newMaxStepRetriesExceededError(workflowID, stepName string, maxRetries int, err error) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Step %s has exceeded its maximum of %d retries: %v", stepName, maxRetries, err), Code: MaxStepRetriesExceeded, diff --git a/dbos/logger_test.go b/dbos/logger_test.go index 018b1a02..b3c82df7 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -10,12 +10,18 @@ import ( func TestLogger(t *testing.T) { t.Run("Default logger", func(t *testing.T) { - err := Launch() // Launch with default logger + executor, err := NewExecutor() // Create executor with default logger + if err != nil { + t.Fatalf("Failed to create executor with default logger: %v", err) + } + err = executor.Launch() if err != nil { t.Fatalf("Failed to launch with default logger: %v", err) } t.Cleanup(func() { - Shutdown() + if executor != nil { + executor.Shutdown() + } }) if logger == nil { @@ -37,12 +43,18 @@ func TestLogger(t *testing.T) { // Add some context to the slog logger slogLogger = slogLogger.With("service", "dbos-test", "environment", "test") - err := Launch(WithLogger(slogLogger)) + executor, err := NewExecutor(WithLogger(slogLogger)) + if err != nil { + t.Fatalf("Failed to create executor with custom logger: %v", err) + } + err = executor.Launch() if err != nil { t.Fatalf("Failed to launch with custom logger: %v", err) } t.Cleanup(func() { - Shutdown() + if executor != nil { + executor.Shutdown() + } }) if logger == nil { diff --git a/dbos/queue.go b/dbos/queue.go index b12bba2d..6850ba98 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -30,65 +30,65 @@ type RateLimiter struct { } type WorkflowQueue struct { - Name string - WorkerConcurrency *int - GlobalConcurrency *int - PriorityEnabled bool - Limiter *RateLimiter - MaxTasksPerIteration int + name string + workerConcurrency *int + globalConcurrency *int + priorityEnabled bool + limiter *RateLimiter + maxTasksPerIteration int } -// QueueOption is a functional option for configuring a workflow queue -type QueueOption func(*WorkflowQueue) +// queueOption is a functional option for configuring a workflow queue +type queueOption func(*WorkflowQueue) -func WithWorkerConcurrency(concurrency int) QueueOption { +func WithWorkerConcurrency(concurrency int) queueOption { return func(q *WorkflowQueue) { - q.WorkerConcurrency = &concurrency + q.workerConcurrency = &concurrency } } -func WithGlobalConcurrency(concurrency int) QueueOption { +func WithGlobalConcurrency(concurrency int) queueOption { return func(q *WorkflowQueue) { - q.GlobalConcurrency = &concurrency + q.globalConcurrency = &concurrency } } -func WithPriorityEnabled(enabled bool) QueueOption { +func WithPriorityEnabled(enabled bool) queueOption { return func(q *WorkflowQueue) { - q.PriorityEnabled = enabled + q.priorityEnabled = enabled } } -func WithRateLimiter(limiter *RateLimiter) QueueOption { +func WithRateLimiter(limiter *RateLimiter) queueOption { return func(q *WorkflowQueue) { - q.Limiter = limiter + q.limiter = limiter } } -func WithMaxTasksPerIteration(maxTasks int) QueueOption { +func WithMaxTasksPerIteration(maxTasks int) queueOption { return func(q *WorkflowQueue) { - q.MaxTasksPerIteration = maxTasks + q.maxTasksPerIteration = maxTasks } } // NewWorkflowQueue creates a new workflow queue with optional configuration -func NewWorkflowQueue(name string, options ...QueueOption) WorkflowQueue { - if getExecutor() != nil { +func NewWorkflowQueue(name string, options ...queueOption) WorkflowQueue { + if dbos != nil { getLogger().Warn("NewWorkflowQueue called after DBOS initialization, dynamic registration is not supported") return WorkflowQueue{} } if _, exists := workflowQueueRegistry[name]; exists { - panic(NewConflictingRegistrationError(name)) + panic(newConflictingRegistrationError(name)) } // Create queue with default settings q := WorkflowQueue{ - Name: name, - WorkerConcurrency: nil, - GlobalConcurrency: nil, - PriorityEnabled: false, - Limiter: nil, - MaxTasksPerIteration: _DEFAULT_MAX_TASKS_PER_ITERATION, + name: name, + workerConcurrency: nil, + globalConcurrency: nil, + priorityEnabled: false, + limiter: nil, + maxTasksPerIteration: _DEFAULT_MAX_TASKS_PER_ITERATION, } // Apply functional options @@ -122,7 +122,7 @@ func queueRunner(ctx context.Context) { for queueName, queue := range workflowQueueRegistry { getLogger().Debug("Processing queue", "queue_name", queueName) // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(ctx, queue) + dequeuedWorkflows, err := dbos.systemDB.DequeueWorkflows(ctx, queue) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 310b9828..6e7ad626 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -74,7 +74,7 @@ func queueWorkflowWithChild(ctx context.Context, input string) (string, error) { func queueWorkflowThatEnqueues(ctx context.Context, input string) (string, error) { // Enqueue another workflow to the same queue - enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.Name)) + enqueuedHandle, err := queueWf(ctx, input+"-enqueued", WithQueue(queue.name)) if err != nil { return "", fmt.Errorf("failed to enqueue workflow: %v", err) } @@ -92,7 +92,7 @@ func TestWorkflowQueues(t *testing.T) { setupDBOS(t) t.Run("EnqueueWorkflow", func(t *testing.T) { - handle, err := queueWf(context.Background(), "test-input", WithQueue(queue.Name)) + handle, err := queueWf(context.Background(), "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow: %v", err) } @@ -116,7 +116,7 @@ func TestWorkflowQueues(t *testing.T) { }) t.Run("EnqueuedWorkflowStartsChildWorkflow", func(t *testing.T) { - handle, err := queueWfWithChild(context.Background(), "test-input", WithQueue(queue.Name)) + handle, err := queueWfWithChild(context.Background(), "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow with child: %v", err) } @@ -138,7 +138,7 @@ func TestWorkflowQueues(t *testing.T) { }) t.Run("WorkflowEnqueuesAnotherWorkflow", func(t *testing.T) { - handle, err := queueWfThatEnqueues(context.Background(), "test-input", WithQueue(queue.Name)) + handle, err := queueWfThatEnqueues(context.Background(), "test-input", WithQueue(queue.name)) if err != nil { t.Fatalf("failed to enqueue workflow that enqueues another workflow: %v", err) } @@ -161,7 +161,7 @@ func TestWorkflowQueues(t *testing.T) { t.Run("DynamicRegistration", func(t *testing.T) { q := NewWorkflowQueue("dynamic-queue") - if len(q.Name) > 0 { + if len(q.name) > 0 { t.Fatalf("expected nil queue for dynamic registration after DBOS initialization, got %v", q) } }) @@ -170,7 +170,7 @@ func TestWorkflowQueues(t *testing.T) { workflowID := "blocking-workflow-test" // Enqueue the workflow for the first time - originalHandle, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.Name), WithWorkflowID(workflowID)) + originalHandle, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to enqueue blocking workflow: %v", err) } @@ -181,7 +181,7 @@ func TestWorkflowQueues(t *testing.T) { // Try to enqueue the same workflow more times for i := range dlqMaxRetries * 2 { - _, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.Name), WithWorkflowID(workflowID)) + _, err := enqueueWorkflowDLQ(context.Background(), "test-input", WithQueue(dlqEnqueueQueue.name), WithWorkflowID(workflowID)) if err != nil { t.Fatalf("failed to enqueue workflow attempt %d: %v", i+1, err) } @@ -259,7 +259,7 @@ var ( recoveryWorkflow = WithWorkflow(func(ctx context.Context, input string) ([]int, error) { handles := make([]WorkflowHandle[int], 0, 5) // 5 queued steps for i := range 5 { - handle, err := recoveryStepWorkflow(ctx, i, WithQueue(recoveryQueue.Name)) + handle, err := recoveryStepWorkflow(ctx, i, WithQueue(recoveryQueue.name)) if err != nil { return nil, fmt.Errorf("failed to enqueue step %d: %v", i, err) } @@ -393,12 +393,12 @@ func TestGlobalConcurrency(t *testing.T) { setupDBOS(t) // Enqueue two workflows - handle1, err := globalConcurrencyWorkflow(context.Background(), "workflow1", WithQueue(globalConcurrencyQueue.Name)) + handle1, err := globalConcurrencyWorkflow(context.Background(), "workflow1", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow1: %v", err) } - handle2, err := globalConcurrencyWorkflow(context.Background(), "workflow2", WithQueue(globalConcurrencyQueue.Name)) + handle2, err := globalConcurrencyWorkflow(context.Background(), "workflow2", WithQueue(globalConcurrencyQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow2: %v", err) } @@ -471,19 +471,19 @@ func TestWorkerConcurrency(t *testing.T) { setupDBOS(t) // First enqueue four blocking workflows - handle1, err := blockingWf(context.Background(), 0, WithQueue(workerConcurrencyQueue.Name), WithWorkflowID("worker-cc-wf-1")) + handle1, err := blockingWf(context.Background(), 0, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-1")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 1: %v", err) } - handle2, err := blockingWf(context.Background(), 1, WithQueue(workerConcurrencyQueue.Name), WithWorkflowID("worker-cc-wf-2")) + handle2, err := blockingWf(context.Background(), 1, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-2")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 2: %v", err) } - _, err = blockingWf(context.Background(), 2, WithQueue(workerConcurrencyQueue.Name), WithWorkflowID("worker-cc-wf-3")) + _, err = blockingWf(context.Background(), 2, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-3")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 3: %v", err) } - _, err = blockingWf(context.Background(), 3, WithQueue(workerConcurrencyQueue.Name), WithWorkflowID("worker-cc-wf-4")) + _, err = blockingWf(context.Background(), 3, WithQueue(workerConcurrencyQueue.name), WithWorkflowID("worker-cc-wf-4")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 4: %v", err) } @@ -494,9 +494,9 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[1].IsSet || startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 1 to start, but others have started") } - workflows, err := getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - Status: []WorkflowStatusType{WorkflowStatusEnqueued}, - QueueName: workerConcurrencyQueue.Name, + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + status: []WorkflowStatusType{WorkflowStatusEnqueued}, + queueName: workerConcurrencyQueue.name, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -508,7 +508,7 @@ func TestWorkerConcurrency(t *testing.T) { // Stop the queue runner before changing executor ID to avoid race conditions stopQueueRunner() // Change the EXECUTOR_ID global variable to a different value - EXECUTOR_ID = "worker-2" + _EXECUTOR_ID = "worker-2" // Restart the queue runner restartQueueRunner() @@ -518,9 +518,9 @@ func TestWorkerConcurrency(t *testing.T) { if startEvents[2].IsSet || startEvents[3].IsSet { t.Fatal("expected only blocking workflow 2 to start, but others have started") } - workflows, err = getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - Status: []WorkflowStatusType{WorkflowStatusEnqueued}, - QueueName: workerConcurrencyQueue.Name, + workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + status: []WorkflowStatusType{WorkflowStatusEnqueued}, + queueName: workerConcurrencyQueue.name, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -541,7 +541,7 @@ func TestWorkerConcurrency(t *testing.T) { // Stop the queue runner before changing executor ID to avoid race conditions stopQueueRunner() // Change the executor again and wait for the third workflow to start - EXECUTOR_ID = "local" + _EXECUTOR_ID = "local" // Restart the queue runner restartQueueRunner() startEvents[2].Wait() @@ -550,9 +550,9 @@ func TestWorkerConcurrency(t *testing.T) { t.Fatal("expected only blocking workflow 3 to start, but workflow 4 has started") } // Check that only one workflow is pending - workflows, err = getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - Status: []WorkflowStatusType{WorkflowStatusEnqueued}, - QueueName: workerConcurrencyQueue.Name, + workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + status: []WorkflowStatusType{WorkflowStatusEnqueued}, + queueName: workerConcurrencyQueue.name, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -573,14 +573,14 @@ func TestWorkerConcurrency(t *testing.T) { // Stop the queue runner before changing executor ID to avoid race conditions stopQueueRunner() // change executor again and wait for the fourth workflow to start - EXECUTOR_ID = "worker-2" + _EXECUTOR_ID = "worker-2" // Restart the queue runner restartQueueRunner() startEvents[3].Wait() // Check no workflow is enqueued - workflows, err = getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - Status: []WorkflowStatusType{WorkflowStatusEnqueued}, - QueueName: workerConcurrencyQueue.Name, + workflows, err = dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + status: []WorkflowStatusType{WorkflowStatusEnqueued}, + queueName: workerConcurrencyQueue.name, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -597,7 +597,7 @@ func TestWorkerConcurrency(t *testing.T) { t.Fatal("expected queue entries to be cleaned up after global concurrency test") } - EXECUTOR_ID = "local" // Reset executor ID for future tests + _EXECUTOR_ID = "local" // Reset executor ID for future tests } var ( @@ -622,11 +622,11 @@ func TestWorkerConcurrencyXRecovery(t *testing.T) { setupDBOS(t) // Enqueue two workflows on a queue with worker concurrency = 1 - handle1, err := workerConcurrencyRecoveryBlockingWf1(context.Background(), "workflow1", WithQueue(workerConcurrencyRecoveryQueue.Name), WithWorkflowID("worker-cc-x-recovery-wf-1")) + handle1, err := workerConcurrencyRecoveryBlockingWf1(context.Background(), "workflow1", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-1")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 1: %v", err) } - handle2, err := workerConcurrencyRecoveryBlockingWf2(context.Background(), "workflow2", WithQueue(workerConcurrencyRecoveryQueue.Name), WithWorkflowID("worker-cc-x-recovery-wf-2")) + handle2, err := workerConcurrencyRecoveryBlockingWf2(context.Background(), "workflow2", WithQueue(workerConcurrencyRecoveryQueue.name), WithWorkflowID("worker-cc-x-recovery-wf-2")) if err != nil { t.Fatalf("failed to enqueue blocking workflow 2: %v", err) } @@ -736,7 +736,7 @@ func TestQueueRateLimiter(t *testing.T) { // executed simultaneously, followed by a wait of the period, // followed by the next wave. for i := 0; i < limit*numWaves; i++ { - handle, err := rateLimiterWorkflow(context.Background(), "", WithQueue(rateLimiterQueue.Name)) + handle, err := rateLimiterWorkflow(context.Background(), "", WithQueue(rateLimiterQueue.name)) if err != nil { t.Fatalf("failed to enqueue workflow %d: %v", i, err) } diff --git a/dbos/recovery.go b/dbos/recovery.go index 5bc3bf75..3cf03f7b 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -8,10 +8,10 @@ import ( func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors - pendingWorkflows, err := getExecutor().systemDB.ListWorkflows(ctx, ListWorkflowsDBInput{ - Status: []WorkflowStatusType{WorkflowStatusPending}, - ExecutorIDs: executorIDs, - ApplicationVersion: APP_VERSION, + pendingWorkflows, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + status: []WorkflowStatusType{WorkflowStatusPending}, + executorIDs: executorIDs, + applicationVersion: _APP_VERSION, }) if err != nil { return nil, err @@ -27,7 +27,7 @@ func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]Workf // fmt.Println("Recovering workflow:", workflow.ID, "Name:", workflow.Name, "Input:", workflow.Input, "QueueName:", workflow.QueueName) if workflow.QueueName != "" { - cleared, err := getExecutor().systemDB.ClearQueueAssignment(ctx, workflow.ID) + cleared, err := dbos.systemDB.ClearQueueAssignment(ctx, workflow.ID) if err != nil { getLogger().Error("Error clearing queue assignment for workflow", "workflow_id", workflow.ID, "name", workflow.Name, "error", err) continue @@ -45,7 +45,7 @@ func recoverPendingWorkflows(ctx context.Context, executorIDs []string) ([]Workf } // Convert workflow parameters to options - opts := []WorkflowOption{ + opts := []workflowOption{ WithWorkflowID(workflow.ID), } // XXX we'll figure out the exact timeout/deadline settings later diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 5f2e6a88..5040f48e 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -98,8 +98,8 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - WorkflowIDs: []string{directHandle.GetWorkflowID()}, + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -136,7 +136,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from GetWorkflowSteps - steps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) + steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } @@ -215,8 +215,8 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from ListWorkflows - workflows, err := getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - WorkflowIDs: []string{directHandle.GetWorkflowID()}, + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflowIDs: []string{directHandle.GetWorkflowID()}, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -257,7 +257,7 @@ func TestWorkflowEncoding(t *testing.T) { } // Test results from GetWorkflowSteps - steps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) + steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), directHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get workflow steps: %v", err) } diff --git a/dbos/system_database.go b/dbos/system_database.go index a4d68b2c..dcabfbf2 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -25,16 +25,16 @@ type SystemDatabase interface { Launch(ctx context.Context) Shutdown() ResetSystemDB(ctx context.Context) error - InsertWorkflowStatus(ctx context.Context, input InsertWorkflowStatusDBInput) (*InsertWorkflowResult, error) + InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error RecordChildWorkflow(ctx context.Context, input recordChildWorkflowDBInput) error CheckChildWorkflow(ctx context.Context, workflowUUID string, functionID int) (*string, error) - ListWorkflows(ctx context.Context, input ListWorkflowsDBInput) ([]WorkflowStatus, error) - UpdateWorkflowOutcome(ctx context.Context, input UpdateWorkflowOutcomeDBInput) error + ListWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) + UpdateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error AwaitWorkflowResult(ctx context.Context, workflowID string) (any, error) DequeueWorkflows(ctx context.Context, queue WorkflowQueue) ([]dequeuedWorkflow, error) ClearQueueAssignment(ctx context.Context, workflowID string) (bool, error) - CheckOperationExecution(ctx context.Context, input CheckOperationExecutionDBInput) (*RecordedResult, error) + CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) RecordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) Send(ctx context.Context, input WorkflowSendInput) error @@ -57,19 +57,19 @@ func createDatabaseIfNotExists(databaseURL string) error { // Connect to the postgres database parsedURL, err := pgx.ParseConfig(databaseURL) if err != nil { - return NewInitializationError(fmt.Sprintf("failed to parse database URL: %v", err)) + return newInitializationError(fmt.Sprintf("failed to parse database URL: %v", err)) } dbName := parsedURL.Database if dbName == "" { - return NewInitializationError("database name not found in URL") + return newInitializationError("database name not found in URL") } serverURL := parsedURL.Copy() serverURL.Database = "postgres" conn, err := pgx.ConnectConfig(context.Background(), serverURL) if err != nil { - return NewInitializationError(fmt.Sprintf("failed to connect to PostgreSQL server: %v", err)) + return newInitializationError(fmt.Sprintf("failed to connect to PostgreSQL server: %v", err)) } defer conn.Close(context.Background()) @@ -78,14 +78,14 @@ func createDatabaseIfNotExists(databaseURL string) error { err = conn.QueryRow(context.Background(), "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", dbName).Scan(&exists) if err != nil { - return NewInitializationError(fmt.Sprintf("failed to check if database exists: %v", err)) + return newInitializationError(fmt.Sprintf("failed to check if database exists: %v", err)) } if !exists { // TODO: validate db name createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{dbName}.Sanitize()) _, err = conn.Exec(context.Background(), createSQL) if err != nil { - return NewInitializationError(fmt.Sprintf("failed to create database %s: %v", dbName, err)) + return newInitializationError(fmt.Sprintf("failed to create database %s: %v", dbName, err)) } getLogger().Info("Database created", "name", dbName) } @@ -104,20 +104,20 @@ func runMigrations(databaseURL string) error { // Create migration source from embedded files d, err := iofs.New(migrationFiles, "migrations") if err != nil { - return NewInitializationError(fmt.Sprintf("failed to create migration source: %v", err)) + return newInitializationError(fmt.Sprintf("failed to create migration source: %v", err)) } // Create migrator m, err := migrate.NewWithSourceInstance("iofs", d, databaseURL) if err != nil { - return NewInitializationError(fmt.Sprintf("failed to create migrator: %v", err)) + return newInitializationError(fmt.Sprintf("failed to create migrator: %v", err)) } defer m.Close() // Run migrations // FIXME: tolerate errors when the migration is bcz we run an older version of transact if err := m.Up(); err != nil && err != migrate.ErrNoChange { - return NewInitializationError(fmt.Sprintf("failed to run migrations: %v", err)) + return newInitializationError(fmt.Sprintf("failed to run migrations: %v", err)) } return nil @@ -205,21 +205,21 @@ func (s *systemDatabase) Shutdown() { /******* WORKFLOWS ********/ /*******************************/ -type InsertWorkflowResult struct { - Attempts int `json:"attempts"` - Status WorkflowStatusType `json:"status"` - Name string `json:"name"` - QueueName *string `json:"queue_name"` - WorkflowDeadlineEpochMs *int64 `json:"workflow_deadline_epoch_ms"` +type insertWorkflowResult struct { + attempts int + status WorkflowStatusType + name string + queueName *string + workflowDeadlineEpochMs *int64 } -type InsertWorkflowStatusDBInput struct { +type insertWorkflowStatusDBInput struct { status WorkflowStatus maxRetries int tx pgx.Tx } -func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input InsertWorkflowStatusDBInput) (*InsertWorkflowResult, error) { +func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) { if input.tx == nil { return nil, errors.New("transaction is required for InsertWorkflowStatus") } @@ -284,7 +284,7 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input InsertW END RETURNING recovery_attempts, status, name, queue_name, workflow_deadline_epoch_ms` - var result InsertWorkflowResult + var result insertWorkflowResult err = input.tx.QueryRow(ctx, query, input.status.ID, input.status.Status, @@ -307,27 +307,27 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input InsertW WorkflowStatusEnqueued, WorkflowStatusEnqueued, ).Scan( - &result.Attempts, - &result.Status, - &result.Name, - &result.QueueName, - &result.WorkflowDeadlineEpochMs, + &result.attempts, + &result.status, + &result.name, + &result.queueName, + &result.workflowDeadlineEpochMs, ) if err != nil { return nil, fmt.Errorf("failed to insert workflow status: %w", err) } - if len(input.status.Name) > 0 && result.Name != input.status.Name { - return nil, NewConflictingWorkflowError(input.status.ID, fmt.Sprintf("Workflow already exists with a different name: %s, but the provided name is: %s", result.Name, input.status.Name)) + if len(input.status.Name) > 0 && result.name != input.status.Name { + return nil, newConflictingWorkflowError(input.status.ID, fmt.Sprintf("Workflow already exists with a different name: %s, but the provided name is: %s", result.name, input.status.Name)) } - if len(input.status.QueueName) > 0 && result.QueueName != nil && input.status.QueueName != *result.QueueName { - getLogger().Warn("Queue name conflict for workflow", "workflow_id", input.status.ID, "result_queue", *result.QueueName, "status_queue", input.status.QueueName) + if len(input.status.QueueName) > 0 && result.queueName != nil && input.status.QueueName != *result.queueName { + getLogger().Warn("Queue name conflict for workflow", "workflow_id", input.status.ID, "result_queue", *result.queueName, "status_queue", input.status.QueueName) } // Every time we start executing a workflow (and thus attempt to insert its status), we increment `recovery_attempts` by 1. // When this number becomes equal to `maxRetries + 1`, we mark the workflow as `RETRIES_EXCEEDED`. - if result.Status != WorkflowStatusSuccess && result.Status != WorkflowStatusError && - input.maxRetries > 0 && result.Attempts > input.maxRetries+1 { + if result.status != WorkflowStatusSuccess && result.status != WorkflowStatusError && + input.maxRetries > 0 && result.attempts > input.maxRetries+1 { // Update workflow status to RETRIES_EXCEEDED and clear queue-related fields dlqQuery := `UPDATE dbos.workflow_status @@ -348,32 +348,32 @@ func (s *systemDatabase) InsertWorkflowStatus(ctx context.Context, input InsertW return nil, fmt.Errorf("failed to commit transaction after marking workflow as RETRIES_EXCEEDED: %w", err) } - return nil, NewDeadLetterQueueError(input.status.ID, input.maxRetries) + return nil, newDeadLetterQueueError(input.status.ID, input.maxRetries) } return &result, nil } // ListWorkflowsInput represents the input parameters for listing workflows -type ListWorkflowsDBInput struct { - WorkflowName string - QueueName string - WorkflowIDPrefix string - WorkflowIDs []string - AuthenticatedUser string - StartTime time.Time - EndTime time.Time - Status []WorkflowStatusType - ApplicationVersion string - ExecutorIDs []string - Limit *int - Offset *int - SortDesc bool - Tx pgx.Tx +type listWorkflowsDBInput struct { + workflowName string + queueName string + workflowIDPrefix string + workflowIDs []string + authenticatedUser string + startTime time.Time + endTime time.Time + status []WorkflowStatusType + applicationVersion string + executorIDs []string + limit *int + offset *int + sortDesc bool + tx pgx.Tx } // ListWorkflows retrieves a list of workflows based on the provided filters -func (s *systemDatabase) ListWorkflows(ctx context.Context, input ListWorkflowsDBInput) ([]WorkflowStatus, error) { +func (s *systemDatabase) ListWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) { qb := newQueryBuilder() // Build the base query @@ -384,35 +384,35 @@ func (s *systemDatabase) ListWorkflows(ctx context.Context, input ListWorkflowsD FROM dbos.workflow_status` // Add filters using query builder - if input.WorkflowName != "" { - qb.addWhere("name", input.WorkflowName) + if input.workflowName != "" { + qb.addWhere("name", input.workflowName) } - if input.QueueName != "" { - qb.addWhere("queue_name", input.QueueName) + if input.queueName != "" { + qb.addWhere("queue_name", input.queueName) } - if input.WorkflowIDPrefix != "" { - qb.addWhereLike("workflow_uuid", input.WorkflowIDPrefix+"%") + if input.workflowIDPrefix != "" { + qb.addWhereLike("workflow_uuid", input.workflowIDPrefix+"%") } - if len(input.WorkflowIDs) > 0 { - qb.addWhereAny("workflow_uuid", input.WorkflowIDs) + if len(input.workflowIDs) > 0 { + qb.addWhereAny("workflow_uuid", input.workflowIDs) } - if input.AuthenticatedUser != "" { - qb.addWhere("authenticated_user", input.AuthenticatedUser) + if input.authenticatedUser != "" { + qb.addWhere("authenticated_user", input.authenticatedUser) } - if !input.StartTime.IsZero() { - qb.addWhereGreaterEqual("created_at", input.StartTime.UnixMilli()) + if !input.startTime.IsZero() { + qb.addWhereGreaterEqual("created_at", input.startTime.UnixMilli()) } - if !input.EndTime.IsZero() { - qb.addWhereLessEqual("created_at", input.EndTime.UnixMilli()) + if !input.endTime.IsZero() { + qb.addWhereLessEqual("created_at", input.endTime.UnixMilli()) } - if len(input.Status) > 0 { - qb.addWhereAny("status", input.Status) + if len(input.status) > 0 { + qb.addWhereAny("status", input.status) } - if input.ApplicationVersion != "" { - qb.addWhere("application_version", input.ApplicationVersion) + if input.applicationVersion != "" { + qb.addWhere("application_version", input.applicationVersion) } - if len(input.ExecutorIDs) > 0 { - qb.addWhereAny("executor_id", input.ExecutorIDs) + if len(input.executorIDs) > 0 { + qb.addWhereAny("executor_id", input.executorIDs) } // Build complete query @@ -424,31 +424,31 @@ func (s *systemDatabase) ListWorkflows(ctx context.Context, input ListWorkflowsD } // Add sorting - if input.SortDesc { + if input.sortDesc { query += " ORDER BY created_at DESC" } else { query += " ORDER BY created_at ASC" } // Add limit and offset - if input.Limit != nil { + if input.limit != nil { qb.argCounter++ query += fmt.Sprintf(" LIMIT $%d", qb.argCounter) - qb.args = append(qb.args, *input.Limit) + qb.args = append(qb.args, *input.limit) } - if input.Offset != nil { + if input.offset != nil { qb.argCounter++ query += fmt.Sprintf(" OFFSET $%d", qb.argCounter) - qb.args = append(qb.args, *input.Offset) + qb.args = append(qb.args, *input.offset) } // Execute the query var rows pgx.Rows var err error - if input.Tx != nil { - rows, err = input.Tx.Query(ctx, query, qb.args...) + if input.tx != nil { + rows, err = input.tx.Query(ctx, query, qb.args...) } else { rows, err = s.pool.Query(ctx, query, qb.args...) } @@ -529,7 +529,7 @@ func (s *systemDatabase) ListWorkflows(ctx context.Context, input ListWorkflowsD return workflows, nil } -type UpdateWorkflowOutcomeDBInput struct { +type updateWorkflowOutcomeDBInput struct { workflowID string status WorkflowStatusType output any @@ -538,7 +538,7 @@ type UpdateWorkflowOutcomeDBInput struct { } // Will evolve as we serialize all output and error types -func (s *systemDatabase) UpdateWorkflowOutcome(ctx context.Context, input UpdateWorkflowOutcomeDBInput) error { +func (s *systemDatabase) UpdateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error { query := `UPDATE dbos.workflow_status SET status = $1, output = $2, error = $3, updated_at = $4, deduplication_id = NULL WHERE workflow_uuid = $5` @@ -573,16 +573,16 @@ func (s *systemDatabase) CancelWorkflow(ctx context.Context, workflowID string) defer tx.Rollback(ctx) // Rollback if not committed // Check if workflow exists - listInput := ListWorkflowsDBInput{ - WorkflowIDs: []string{workflowID}, - Tx: tx, + listInput := listWorkflowsDBInput{ + workflowIDs: []string{workflowID}, + tx: tx, } wfs, err := s.ListWorkflows(ctx, listInput) if err != nil { return err } if len(wfs) == 0 { - return NewNonExistentWorkflowError(workflowID) + return newNonExistentWorkflowError(workflowID) } wf := wfs[0] @@ -640,7 +640,7 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str } return output, errors.New(*errorStr) case WorkflowStatusCancelled: - return nil, NewAwaitedWorkflowCancelledError(workflowID) + return nil, newAwaitedWorkflowCancelledError(workflowID) default: time.Sleep(1 * time.Second) // Wait before checking again } @@ -819,19 +819,19 @@ func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input recordC /******* STEPS ********/ /*******************************/ -type RecordedResult struct { +type recordedResult struct { output any err error } -type CheckOperationExecutionDBInput struct { +type checkOperationExecutionDBInput struct { workflowID string operationID int functionName string tx pgx.Tx } -func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input CheckOperationExecutionDBInput) (*RecordedResult, error) { +func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) { var tx pgx.Tx var err error @@ -860,14 +860,14 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input Chec err = tx.QueryRow(ctx, workflowStatusQuery, input.workflowID).Scan(&workflowStatus) if err != nil { if err == pgx.ErrNoRows { - return nil, NewNonExistentWorkflowError(input.workflowID) + return nil, newNonExistentWorkflowError(input.workflowID) } return nil, fmt.Errorf("failed to get workflow status: %w", err) } // If the workflow is cancelled, raise the exception if workflowStatus == WorkflowStatusCancelled { - return nil, NewWorkflowCancelledError(input.workflowID) + return nil, newWorkflowCancelledError(input.workflowID) } // Execute second query to get operation outputs @@ -887,7 +887,7 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input Chec // If the provided and recorded function name are different, throw an exception if input.functionName != recordedFunctionName { - return nil, NewUnexpectedStepError(input.workflowID, input.operationID, input.functionName, recordedFunctionName) + return nil, newUnexpectedStepError(input.workflowID, input.operationID, input.functionName, recordedFunctionName) } output, err := deserialize(outputString) @@ -899,7 +899,7 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input Chec if errorStr != nil && *errorStr != "" { recordedError = errors.New(*errorStr) } - result := &RecordedResult{ + result := &recordedResult{ output: output, err: recordedError, } @@ -1003,7 +1003,7 @@ func (s *systemDatabase) notificationListenerLoop(ctx context.Context) { } } -const DBOS_NULL_TOPIC = "__null__topic__" +const _DBOS_NULL_TOPIC = "__null__topic__" // Send is a special type of step that sends a message to another workflow. // Three differences with a normal steps: durability and the function run in the same transaction, and we forbid nested step execution @@ -1013,11 +1013,11 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro // Get workflow state from context workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) if !ok || workflowState == nil { - return NewStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") + return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") } if workflowState.isWithinStep { - return NewStepExecutionError(workflowState.WorkflowID, functionName, "cannot call Send within a step") + return newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call Send within a step") } stepID := workflowState.NextStepID() @@ -1029,7 +1029,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro defer tx.Rollback(ctx) // Check if operation was already executed and do nothing if so - checkInput := CheckOperationExecutionDBInput{ + checkInput := checkOperationExecutionDBInput{ workflowID: input.DestinationID, operationID: stepID, functionName: functionName, @@ -1044,7 +1044,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro } // Set default topic if not provided - topic := DBOS_NULL_TOPIC + topic := _DBOS_NULL_TOPIC if len(input.Topic) > 0 { topic = input.Topic } @@ -1062,7 +1062,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro if err != nil { // Check for foreign key violation (destination workflow doesn't exist) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23503" { - return NewNonExistentWorkflowError(input.DestinationID) + return newNonExistentWorkflowError(input.DestinationID) } return fmt.Errorf("failed to insert notification: %w", err) } @@ -1098,18 +1098,18 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any // XXX these checks might be better suited for outside of the system db code. We'll see when we implement the client. workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) if !ok || workflowState == nil { - return nil, NewStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") + return nil, newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") } if workflowState.isWithinStep { - return nil, NewStepExecutionError(workflowState.WorkflowID, functionName, "cannot call Recv within a step") + return nil, newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call Recv within a step") } stepID := workflowState.NextStepID() destinationID := workflowState.WorkflowID // Set default topic if not provided - topic := DBOS_NULL_TOPIC + topic := _DBOS_NULL_TOPIC if len(input.Topic) > 0 { topic = input.Topic } @@ -1122,7 +1122,7 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any // Check if operation was already executed // XXX this might not need to be in the transaction - checkInput := CheckOperationExecutionDBInput{ + checkInput := checkOperationExecutionDBInput{ workflowID: destinationID, operationID: stepID, functionName: functionName, @@ -1146,7 +1146,7 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any if loaded { close(c) getLogger().Error("Receive already called for workflow", "destination_id", destinationID) - return nil, NewWorkflowConflictIDError(destinationID) + return nil, newWorkflowConflictIDError(destinationID) } defer func() { // Clean up the channel after we're done @@ -1258,8 +1258,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue // First check the rate limiter startTimeMs := time.Now().UnixMilli() var numRecentQueries int - if queue.Limiter != nil { - limiterPeriod := time.Duration(queue.Limiter.Period * float64(time.Second)) + if queue.limiter != nil { + limiterPeriod := time.Duration(queue.limiter.Period * float64(time.Second)) // Calculate the cutoff time: current time minus limiter period cutoffTimeMs := time.Now().Add(-limiterPeriod).UnixMilli() @@ -1273,22 +1273,22 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue AND started_at_epoch_ms > $3` err := tx.QueryRow(ctx, limiterQuery, - queue.Name, + queue.name, WorkflowStatusEnqueued, cutoffTimeMs).Scan(&numRecentQueries) if err != nil { return nil, fmt.Errorf("failed to query rate limiter: %w", err) } - if numRecentQueries >= queue.Limiter.Limit { + if numRecentQueries >= queue.limiter.Limit { return []dequeuedWorkflow{}, nil } } // Calculate max_tasks based on concurrency limits - maxTasks := queue.MaxTasksPerIteration + maxTasks := queue.maxTasksPerIteration - if queue.WorkerConcurrency != nil || queue.GlobalConcurrency != nil { + if queue.workerConcurrency != nil || queue.globalConcurrency != nil { // Count pending workflows by executor pendingQuery := ` SELECT executor_id, COUNT(*) as task_count @@ -1296,7 +1296,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue WHERE queue_name = $1 AND status = $2 GROUP BY executor_id` - rows, err := tx.Query(ctx, pendingQuery, queue.Name, WorkflowStatusPending) + rows, err := tx.Query(ctx, pendingQuery, queue.name, WorkflowStatusPending) if err != nil { return nil, fmt.Errorf("failed to query pending workflows: %w", err) } @@ -1312,28 +1312,28 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue pendingWorkflowsDict[executorIDRow] = taskCount } - localPendingWorkflows := pendingWorkflowsDict[EXECUTOR_ID] + localPendingWorkflows := pendingWorkflowsDict[_EXECUTOR_ID] // Check worker concurrency limit - if queue.WorkerConcurrency != nil { - workerConcurrency := *queue.WorkerConcurrency + if queue.workerConcurrency != nil { + workerConcurrency := *queue.workerConcurrency if localPendingWorkflows > workerConcurrency { - getLogger().Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", workerConcurrency) + getLogger().Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", queue.name, "concurrency_limit", workerConcurrency) } availableWorkerTasks := max(workerConcurrency-localPendingWorkflows, 0) maxTasks = availableWorkerTasks } // Check global concurrency limit - if queue.GlobalConcurrency != nil { + if queue.globalConcurrency != nil { globalPendingWorkflows := 0 for _, count := range pendingWorkflowsDict { globalPendingWorkflows += count } - concurrency := *queue.GlobalConcurrency + concurrency := *queue.globalConcurrency if globalPendingWorkflows > concurrency { - getLogger().Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", concurrency) + getLogger().Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", queue.name, "concurrency_limit", concurrency) } availableTasks := max(concurrency-globalPendingWorkflows, 0) if availableTasks < maxTasks { @@ -1345,7 +1345,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue // Build the query to select workflows for dequeueing // Use SKIP LOCKED when no global concurrency is set to avoid blocking, // otherwise use NOWAIT to ensure consistent view across processes - skipLocks := queue.GlobalConcurrency == nil + skipLocks := queue.globalConcurrency == nil var lockClause string if skipLocks { lockClause = "FOR UPDATE SKIP LOCKED" @@ -1354,7 +1354,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue } var query string - if queue.PriorityEnabled { + if queue.priorityEnabled { query = fmt.Sprintf(` SELECT workflow_uuid FROM dbos.workflow_status @@ -1379,7 +1379,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue } // Execute the query to get workflow IDs - rows, err := tx.Query(ctx, query, queue.Name, WorkflowStatusEnqueued, APP_VERSION) + rows, err := tx.Query(ctx, query, queue.name, WorkflowStatusEnqueued, _APP_VERSION) if err != nil { return nil, fmt.Errorf("failed to query enqueued workflows: %w", err) } @@ -1402,8 +1402,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue var retWorkflows []dequeuedWorkflow for _, id := range dequeuedIDs { // If we have a limiter, stop dequeueing workflows when the number of workflows started this period exceeds the limit. - if queue.Limiter != nil { - if len(retWorkflows)+numRecentQueries >= queue.Limiter.Limit { + if queue.limiter != nil { + if len(retWorkflows)+numRecentQueries >= queue.limiter.Limit { break } } @@ -1429,8 +1429,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue var inputString *string err := tx.QueryRow(ctx, updateQuery, WorkflowStatusPending, - APP_VERSION, - EXECUTOR_ID, + _APP_VERSION, + _EXECUTOR_ID, startTimeMs, id).Scan(&retWorkflow.name, &inputString) diff --git a/dbos/utils_test.go b/dbos/utils_test.go index 45c2cd7c..9a0c1293 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -46,11 +46,16 @@ func setupDBOS(t *testing.T) { t.Fatalf("failed to drop test database: %v", err) } - err = Launch() + executor, err := NewExecutor() if err != nil { t.Fatalf("failed to create DBOS instance: %v", err) } + err = executor.Launch() + if err != nil { + t.Fatalf("failed to launch DBOS instance: %v", err) + } + if dbos == nil { t.Fatal("expected DBOS instance but got nil") } @@ -58,7 +63,9 @@ func setupDBOS(t *testing.T) { // Register cleanup to run after test completes t.Cleanup(func() { fmt.Println("Cleaning up DBOS instance...") - Shutdown() + if executor != nil { + executor.Shutdown() + } }) } @@ -141,7 +148,7 @@ func queueEntriesAreCleanedUp() bool { success := false for range maxTries { // Begin transaction - tx, err := getExecutor().systemDB.(*systemDatabase).pool.Begin(context.Background()) + tx, err := dbos.systemDB.(*systemDatabase).pool.Begin(context.Background()) if err != nil { return false } diff --git a/dbos/workflow.go b/dbos/workflow.go index 9afc6d14..0e15d7fa 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -55,6 +55,7 @@ type WorkflowStatus struct { } // WorkflowState holds the runtime state for a workflow execution +// TODO: this should be an internal type. Workflows should have aptly named getters to access the state type WorkflowState struct { WorkflowID string stepCounter int @@ -102,7 +103,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { if isChildWorkflow { encodedOutput, encErr := serialize(outcome.result) if encErr != nil { - return *new(R), NewWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) + return *new(R), newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: parentWorkflowState.WorkflowID, @@ -111,7 +112,7 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { output: encodedOutput, err: outcome.err, } - recordResultErr := getExecutor().systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := dbos.systemDB.RecordChildGetResult(ctx, recordGetResultInput) if recordResultErr != nil { // XXX do we want to fail this? getLogger().Error("failed to record get result", "error", recordResultErr) @@ -123,14 +124,14 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowHandle[R]) GetStatus() (WorkflowStatus, error) { ctx := context.Background() - workflowStatuses, err := getExecutor().systemDB.ListWorkflows(ctx, ListWorkflowsDBInput{ - WorkflowIDs: []string{h.workflowID}, + workflowStatuses, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowIDs: []string{h.workflowID}, }) if err != nil { return WorkflowStatus{}, fmt.Errorf("failed to get workflow status: %w", err) } if len(workflowStatuses) == 0 { - return WorkflowStatus{}, NewNonExistentWorkflowError(h.workflowID) + return WorkflowStatus{}, newNonExistentWorkflowError(h.workflowID) } return workflowStatuses[0], nil } @@ -144,12 +145,12 @@ type workflowPollingHandle[R any] struct { } func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { - result, err := getExecutor().systemDB.AwaitWorkflowResult(ctx, h.workflowID) + result, err := dbos.systemDB.AwaitWorkflowResult(ctx, h.workflowID) if result != nil { typedResult, ok := result.(R) if !ok { // TODO check what this looks like in practice - return *new(R), NewWorkflowUnexpectedResultType(h.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", result)) + return *new(R), newWorkflowUnexpectedResultType(h.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", result)) } // If we are calling GetResult inside a workflow, record the result as a step result parentWorkflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) @@ -157,7 +158,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { if isChildWorkflow { encodedOutput, encErr := serialize(typedResult) if encErr != nil { - return *new(R), NewWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) + return *new(R), newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: parentWorkflowState.WorkflowID, @@ -166,7 +167,7 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { output: encodedOutput, err: err, } - recordResultErr := getExecutor().systemDB.RecordChildGetResult(ctx, recordGetResultInput) + recordResultErr := dbos.systemDB.RecordChildGetResult(ctx, recordGetResultInput) if recordResultErr != nil { // XXX do we want to fail this? getLogger().Error("failed to record get result", "error", recordResultErr) @@ -180,14 +181,14 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { // GetStatus returns the current status of the workflow from the database func (h *workflowPollingHandle[R]) GetStatus() (WorkflowStatus, error) { ctx := context.Background() - workflowStatuses, err := getExecutor().systemDB.ListWorkflows(ctx, ListWorkflowsDBInput{ - WorkflowIDs: []string{h.workflowID}, + workflowStatuses, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowIDs: []string{h.workflowID}, }) if err != nil { return WorkflowStatus{}, fmt.Errorf("failed to get workflow status: %w", err) } if len(workflowStatuses) == 0 { - return WorkflowStatus{}, NewNonExistentWorkflowError(h.workflowID) + return WorkflowStatus{}, newNonExistentWorkflowError(h.workflowID) } return workflowStatuses[0], nil } @@ -199,7 +200,7 @@ func (h *workflowPollingHandle[R]) GetWorkflowID() string { /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ -type TypedErasedWorkflowWrapperFunc func(ctx context.Context, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) +type TypedErasedWorkflowWrapperFunc func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) type workflowRegistryEntry struct { wrappedFunction TypedErasedWorkflowWrapperFunc @@ -216,7 +217,7 @@ func registerWorkflow(fqn string, fn TypedErasedWorkflowWrapperFunc, maxRetries if _, exists := registry[fqn]; exists { getLogger().Error("workflow function already registered", "fqn", fqn) - panic(NewConflictingRegistrationError(fqn)) + panic(newConflictingRegistrationError(fqn)) } registry[fqn] = workflowRegistryEntry{ @@ -225,38 +226,38 @@ func registerWorkflow(fqn string, fn TypedErasedWorkflowWrapperFunc, maxRetries } } -type WorkflowRegistrationParams struct { - CronSchedule string - MaxRetries int +type workflowRegistrationParams struct { + cronSchedule string + maxRetries int // Likely we will allow a name here } -type WorkflowRegistrationOption func(*WorkflowRegistrationParams) +type workflowRegistrationOption func(*workflowRegistrationParams) const ( - DEFAULT_MAX_RECOVERY_ATTEMPTS = 100 + _DEFAULT_MAX_RECOVERY_ATTEMPTS = 100 ) -func WithMaxRetries(maxRetries int) WorkflowRegistrationOption { - return func(p *WorkflowRegistrationParams) { - p.MaxRetries = maxRetries +func WithMaxRetries(maxRetries int) workflowRegistrationOption { + return func(p *workflowRegistrationParams) { + p.maxRetries = maxRetries } } -func WithSchedule(schedule string) WorkflowRegistrationOption { - return func(p *WorkflowRegistrationParams) { - p.CronSchedule = schedule +func WithSchedule(schedule string) workflowRegistrationOption { + return func(p *workflowRegistrationParams) { + p.cronSchedule = schedule } } -func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...WorkflowRegistrationOption) WorkflowWrapperFunc[P, R] { - if getExecutor() != nil { +func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrationOption) WorkflowWrapperFunc[P, R] { + if dbos != nil { getLogger().Warn("WithWorkflow called after DBOS initialization, dynamic registration is not supported") return nil } - registrationParams := WorkflowRegistrationParams{ - MaxRetries: DEFAULT_MAX_RECOVERY_ATTEMPTS, + registrationParams := workflowRegistrationParams{ + maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, } for _, opt := range opts { opt(®istrationParams) @@ -274,13 +275,13 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...WorkflowRegistrat gob.Register(r) // Wrap the function in a durable workflow - wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx context.Context, workflowInput P, opts ...WorkflowOption) (WorkflowHandle[R], error) { - opts = append(opts, WithWorkflowMaxRetries(registrationParams.MaxRetries)) + wrappedFunction := WorkflowWrapperFunc[P, R](func(ctx context.Context, workflowInput P, opts ...workflowOption) (WorkflowHandle[R], error) { + opts = append(opts, WithWorkflowMaxRetries(registrationParams.maxRetries)) return runAsWorkflow(ctx, fn, workflowInput, opts...) }) // If this is a scheduled workflow, register a cron job - if registrationParams.CronSchedule != "" { + if registrationParams.cronSchedule != "" { if reflect.TypeOf(p) != reflect.TypeOf(time.Time{}) { panic(fmt.Sprintf("scheduled workflow function must accept ScheduledWorkflowInput as input, got %T", p)) } @@ -288,9 +289,9 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...WorkflowRegistrat workflowScheduler = cron.New(cron.WithSeconds()) } var entryID cron.EntryID - entryID, err := workflowScheduler.AddFunc(registrationParams.CronSchedule, func() { + entryID, err := workflowScheduler.AddFunc(registrationParams.cronSchedule, func() { // Execute the workflow on the cron schedule once DBOS is launched - if getExecutor() == nil { + if dbos == nil { return } // Get the scheduled time from the cron entry @@ -309,10 +310,10 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...WorkflowRegistrat } // Register a type-erased version of the durable workflow for recovery - typeErasedWrapper := func(ctx context.Context, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { + typeErasedWrapper := func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) { typedInput, ok := input.(P) if !ok { - return nil, NewWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) + return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } handle, err := wrappedFunction(ctx, typedInput, opts...) @@ -321,7 +322,7 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...WorkflowRegistrat } return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID()}, nil } - registerWorkflow(fqn, typeErasedWrapper, registrationParams.MaxRetries) + registerWorkflow(fqn, typeErasedWrapper, registrationParams.maxRetries) return wrappedFunction } @@ -332,62 +333,63 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...WorkflowRegistrat type contextKey string +// TODO this should be a private type, once we have proper getter for a workflow state const WorkflowStateKey contextKey = "workflowState" type WorkflowFunc[P any, R any] func(ctx context.Context, input P) (R, error) -type WorkflowWrapperFunc[P any, R any] func(ctx context.Context, input P, opts ...WorkflowOption) (WorkflowHandle[R], error) +type WorkflowWrapperFunc[P any, R any] func(ctx context.Context, input P, opts ...workflowOption) (WorkflowHandle[R], error) -type WorkflowParams struct { - WorkflowID string - Timeout time.Duration - Deadline time.Time - QueueName string - ApplicationVersion string - MaxRetries int +type workflowParams struct { + workflowID string + timeout time.Duration + deadline time.Time + queueName string + applicationVersion string + maxRetries int } -type WorkflowOption func(*WorkflowParams) +type workflowOption func(*workflowParams) -func WithWorkflowID(id string) WorkflowOption { - return func(p *WorkflowParams) { - p.WorkflowID = id +func WithWorkflowID(id string) workflowOption { + return func(p *workflowParams) { + p.workflowID = id } } -func WithTimeout(timeout time.Duration) WorkflowOption { - return func(p *WorkflowParams) { - p.Timeout = timeout +func WithTimeout(timeout time.Duration) workflowOption { + return func(p *workflowParams) { + p.timeout = timeout } } -func WithDeadline(deadline time.Time) WorkflowOption { - return func(p *WorkflowParams) { - p.Deadline = deadline +func WithDeadline(deadline time.Time) workflowOption { + return func(p *workflowParams) { + p.deadline = deadline } } -func WithQueue(queueName string) WorkflowOption { - return func(p *WorkflowParams) { - p.QueueName = queueName +func WithQueue(queueName string) workflowOption { + return func(p *workflowParams) { + p.queueName = queueName } } -func WithApplicationVersion(version string) WorkflowOption { - return func(p *WorkflowParams) { - p.ApplicationVersion = version +func WithApplicationVersion(version string) workflowOption { + return func(p *workflowParams) { + p.applicationVersion = version } } -func WithWorkflowMaxRetries(maxRetries int) WorkflowOption { - return func(p *WorkflowParams) { - p.MaxRetries = maxRetries +func WithWorkflowMaxRetries(maxRetries int) workflowOption { + return func(p *workflowParams) { + p.maxRetries = maxRetries } } -func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], input P, opts ...WorkflowOption) (WorkflowHandle[R], error) { +func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], input P, opts ...workflowOption) (WorkflowHandle[R], error) { // Apply options to build params - params := WorkflowParams{ - ApplicationVersion: APP_VERSION, + params := workflowParams{ + applicationVersion: _APP_VERSION, } for _, opt := range opts { opt(¶ms) @@ -404,7 +406,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp // Generate an ID for the workflow if not provided var workflowID string - if params.WorkflowID == "" { + if params.workflowID == "" { if isChildWorkflow { stepID := parentWorkflowState.NextStepID() workflowID = fmt.Sprintf("%s-%d", parentWorkflowState.WorkflowID, stepID) @@ -412,14 +414,14 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp workflowID = uuid.New().String() } } else { - workflowID = params.WorkflowID + workflowID = params.workflowID } // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := getExecutor().systemDB.CheckChildWorkflow(dbosWorkflowContext, parentWorkflowState.WorkflowID, parentWorkflowState.stepCounter) + childWorkflowID, err := dbos.systemDB.CheckChildWorkflow(dbosWorkflowContext, parentWorkflowState.WorkflowID, parentWorkflowState.stepCounter) if err != nil { - return nil, NewWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("checking child workflow: %v", err)) + return nil, newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { return &workflowPollingHandle[R]{workflowID: *childWorkflowID}, nil @@ -427,7 +429,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp } var status WorkflowStatusType - if params.QueueName != "" { + if params.queueName != "" { status = WorkflowStatusEnqueued } else { status = WorkflowStatusPending @@ -435,41 +437,41 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp workflowStatus := WorkflowStatus{ Name: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // TODO factor out somewhere else so we dont' have to reflect here - ApplicationVersion: params.ApplicationVersion, - ExecutorID: EXECUTOR_ID, + ApplicationVersion: params.applicationVersion, + ExecutorID: _EXECUTOR_ID, Status: status, ID: workflowID, CreatedAt: time.Now(), - Deadline: params.Deadline, // TODO compute the deadline based on the timeout - Timeout: params.Timeout, + Deadline: params.deadline, // TODO compute the deadline based on the timeout + Timeout: params.timeout, Input: input, - ApplicationID: APP_ID, - QueueName: params.QueueName, + ApplicationID: _APP_ID, + QueueName: params.queueName, } // Init status and record child workflow relationship in a single transaction - tx, err := getExecutor().systemDB.(*systemDatabase).pool.Begin(dbosWorkflowContext) + tx, err := dbos.systemDB.(*systemDatabase).pool.Begin(dbosWorkflowContext) if err != nil { - return nil, NewWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) } defer tx.Rollback(dbosWorkflowContext) // Rollback if not committed // Insert workflow status with transaction - insertInput := InsertWorkflowStatusDBInput{ + insertInput := insertWorkflowStatusDBInput{ status: workflowStatus, - maxRetries: params.MaxRetries, + maxRetries: params.maxRetries, tx: tx, } - insertStatusResult, err := getExecutor().systemDB.InsertWorkflowStatus(dbosWorkflowContext, insertInput) + insertStatusResult, err := dbos.systemDB.InsertWorkflowStatus(dbosWorkflowContext, insertInput) if err != nil { return nil, err } // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), - if len(params.QueueName) > 0 || insertStatusResult.Status == WorkflowStatusSuccess || insertStatusResult.Status == WorkflowStatusError { + if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { // Commit the transaction to update the number of attempts and/or enact the enqueue if err := tx.Commit(dbosWorkflowContext); err != nil { - return nil, NewWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } return &workflowPollingHandle[R]{workflowID: workflowStatus.ID}, nil } @@ -485,15 +487,15 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp functionID: stepID, tx: tx, } - err = getExecutor().systemDB.RecordChildWorkflow(dbosWorkflowContext, childInput) + err = dbos.systemDB.RecordChildWorkflow(dbosWorkflowContext, childInput) if err != nil { - return nil, NewWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("recording child workflow: %v", err)) + return nil, newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("recording child workflow: %v", err)) } } // Commit the transaction if err := tx.Commit(dbosWorkflowContext); err != nil { - return nil, NewWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } // Channel to receive the outcome from the goroutine @@ -521,7 +523,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp if err != nil { status = WorkflowStatusError } - recordErr := getExecutor().systemDB.UpdateWorkflowOutcome(dbosWorkflowContext, UpdateWorkflowOutcomeDBInput{workflowID: workflowStatus.ID, status: status, err: err, output: result}) + recordErr := dbos.systemDB.UpdateWorkflowOutcome(dbosWorkflowContext, updateWorkflowOutcomeDBInput{workflowID: workflowStatus.ID, status: status, err: err, output: result}) if recordErr != nil { outcomeChan <- workflowOutcome[R]{result: *new(R), err: recordErr} close(outcomeChan) // Close the channel to signal completion @@ -562,40 +564,40 @@ type StepParams struct { MaxInterval time.Duration } -// StepOption is a functional option for configuring step parameters -type StepOption func(*StepParams) +// stepOption is a functional option for configuring step parameters +type stepOption func(*StepParams) // WithStepMaxRetries sets the maximum number of retries for a step -func WithStepMaxRetries(maxRetries int) StepOption { +func WithStepMaxRetries(maxRetries int) stepOption { return func(p *StepParams) { p.MaxRetries = maxRetries } } // WithBackoffFactor sets the backoff factor for retries (multiplier for exponential backoff) -func WithBackoffFactor(backoffFactor float64) StepOption { +func WithBackoffFactor(backoffFactor float64) stepOption { return func(p *StepParams) { p.BackoffFactor = backoffFactor } } // WithBaseInterval sets the base delay for the first retry -func WithBaseInterval(baseInterval time.Duration) StepOption { +func WithBaseInterval(baseInterval time.Duration) stepOption { return func(p *StepParams) { p.BaseInterval = baseInterval } } // WithMaxInterval sets the maximum delay for retries -func WithMaxInterval(maxInterval time.Duration) StepOption { +func WithMaxInterval(maxInterval time.Duration) stepOption { return func(p *StepParams) { p.MaxInterval = maxInterval } } -func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, opts ...StepOption) (R, error) { +func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, opts ...stepOption) (R, error) { if fn == nil { - return *new(R), NewStepExecutionError("", "", "step function cannot be nil") + return *new(R), newStepExecutionError("", "", "step function cannot be nil") } operationName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() @@ -614,7 +616,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // Get workflow state from context workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) if !ok || workflowState == nil { - return *new(R), NewStepExecutionError("", operationName, "workflow state not found in context: are you running this step within a workflow?") + return *new(R), newStepExecutionError("", operationName, "workflow state not found in context: are you running this step within a workflow?") } // If within a step, just run the function directly @@ -626,13 +628,13 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op operationID := workflowState.NextStepID() // Check the step is cancelled, has already completed, or is called with a different name - recordedOutput, err := getExecutor().systemDB.CheckOperationExecution(ctx, CheckOperationExecutionDBInput{ + recordedOutput, err := dbos.systemDB.CheckOperationExecution(ctx, checkOperationExecutionDBInput{ workflowID: workflowState.WorkflowID, operationID: operationID, functionName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), }) if err != nil { - return *new(R), NewStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("checking operation execution: %v", err)) + return *new(R), newStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("checking operation execution: %v", err)) } if recordedOutput != nil { return recordedOutput.output.(R), recordedOutput.err @@ -666,7 +668,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // Wait before retry select { case <-ctx.Done(): - return *new(R), NewStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("context cancelled during retry: %v", ctx.Err())) + return *new(R), newStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("context cancelled during retry: %v", ctx.Err())) case <-time.After(delay): // Continue to retry } @@ -684,7 +686,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // If max retries reached, create MaxStepRetriesExceeded error if retry == params.MaxRetries { - stepError = NewMaxStepRetriesExceededError(workflowState.WorkflowID, operationName, params.MaxRetries, joinedErrors) + stepError = newMaxStepRetriesExceededError(workflowState.WorkflowID, operationName, params.MaxRetries, joinedErrors) break } } @@ -698,9 +700,9 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op err: stepError, output: stepOutput, } - recErr := getExecutor().systemDB.RecordOperationResult(ctx, dbInput) + recErr := dbos.systemDB.RecordOperationResult(ctx, dbInput) if recErr != nil { - return *new(R), NewStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("recording step outcome: %v", recErr)) + return *new(R), newStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("recording step outcome: %v", recErr)) } return stepOutput, stepError @@ -717,7 +719,7 @@ type WorkflowSendInput struct { } func Send(ctx context.Context, input WorkflowSendInput) error { - return getExecutor().systemDB.Send(ctx, input) + return dbos.systemDB.Send(ctx, input) } type WorkflowRecvInput struct { @@ -726,7 +728,7 @@ type WorkflowRecvInput struct { } func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { - msg, err := getExecutor().systemDB.Recv(ctx, input) + msg, err := dbos.systemDB.Recv(ctx, input) if err != nil { return *new(R), err } @@ -736,7 +738,7 @@ func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { var ok bool typedMessage, ok = msg.(R) if !ok { - return *new(R), NewWorkflowUnexpectedResultType("", fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", msg)) + return *new(R), newWorkflowUnexpectedResultType("", fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", msg)) } } return typedMessage, nil @@ -748,14 +750,14 @@ func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { func RetrieveWorkflow[R any](workflowID string) (workflowPollingHandle[R], error) { ctx := context.Background() - workflowStatus, err := getExecutor().systemDB.ListWorkflows(ctx, ListWorkflowsDBInput{ - WorkflowIDs: []string{workflowID}, + workflowStatus, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ + workflowIDs: []string{workflowID}, }) if err != nil { return workflowPollingHandle[R]{}, fmt.Errorf("failed to retrieve workflow status: %w", err) } if len(workflowStatus) == 0 { - return workflowPollingHandle[R]{}, NewNonExistentWorkflowError(workflowID) + return workflowPollingHandle[R]{}, newNonExistentWorkflowError(workflowID) } return workflowPollingHandle[R]{workflowID: workflowID}, nil } diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index a3ffb8f1..cc396785 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -116,7 +116,7 @@ var ( ) func TestAppVersion(t *testing.T) { - if _, err := hex.DecodeString(APP_VERSION); err != nil { + if _, err := hex.DecodeString(_APP_VERSION); err != nil { t.Fatalf("APP_VERSION is not a valid hex string: %v", err) } @@ -136,7 +136,7 @@ func TestAppVersion(t *testing.T) { return "new-registry-workflow-" + input, nil }) hash2 := computeApplicationVersion() - if APP_VERSION == hash2 { + if _APP_VERSION == hash2 { t.Fatalf("APP_VERSION hash did not change after replacing registry") } } @@ -146,7 +146,7 @@ func TestWorkflowsWrapping(t *testing.T) { type testCase struct { name string - workflowFunc func(context.Context, string, ...WorkflowOption) (any, error) + workflowFunc func(context.Context, string, ...workflowOption) (any, error) input string expectedResult any expectError bool @@ -156,7 +156,7 @@ func TestWorkflowsWrapping(t *testing.T) { tests := []testCase{ { name: "SimpleWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWf(ctx, input, opts...) if err != nil { return nil, err @@ -178,7 +178,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowError", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWfError(ctx, input, opts...) if err != nil { return nil, err @@ -191,7 +191,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStep", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWfWithStep(ctx, input, opts...) if err != nil { return nil, err @@ -204,7 +204,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowStruct", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWfStruct(ctx, input, opts...) if err != nil { return nil, err @@ -217,7 +217,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "ValueReceiverWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWfValue(ctx, input, opts...) if err != nil { return nil, err @@ -230,7 +230,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "interfaceMethodWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWfIface(ctx, input, opts...) if err != nil { return nil, err @@ -243,7 +243,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "GenericWorkflow", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { // For generic workflow, we need to convert string to int for testing handle, err := wfInt(ctx, "42", opts...) // FIXME for now this returns a string because sys db accepts this if err != nil { @@ -257,7 +257,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "ClosureWithCapturedState", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := wfClose(ctx, input, opts...) if err != nil { return nil, err @@ -270,7 +270,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "AnonymousClosure", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := anonymousWf(ctx, input, opts...) if err != nil { return nil, err @@ -283,7 +283,7 @@ func TestWorkflowsWrapping(t *testing.T) { }, { name: "SimpleWorkflowWithStepError", - workflowFunc: func(ctx context.Context, input string, opts ...WorkflowOption) (any, error) { + workflowFunc: func(ctx context.Context, input string, opts ...workflowOption) (any, error) { handle, err := simpleWfWithStepError(ctx, input, opts...) if err != nil { return nil, err @@ -390,7 +390,7 @@ func TestSteps(t *testing.T) { t.Fatalf("expected result 'from step', got '%s'", result) } - steps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatal("failed to list steps:", err) } @@ -444,7 +444,7 @@ func TestSteps(t *testing.T) { } // Verify that the failed step was still recorded in the database - steps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatal("failed to get workflow steps:", err) } @@ -684,8 +684,8 @@ func TestWorkflowRecovery(t *testing.T) { } // Using ListWorkflows, retrieve the status of the workflow - workflows, err := getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{ - WorkflowIDs: []string{handle1.GetWorkflowID()}, + workflows, err := dbos.systemDB.ListWorkflows(context.Background(), listWorkflowsDBInput{ + workflowIDs: []string{handle1.GetWorkflowID()}, }) if err != nil { t.Fatalf("failed to list workflows: %v", err) @@ -867,7 +867,7 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { deadLetterQueueStartEvent.Clear() // Attempt to recover the blocked workflow many times (should never fail) handles := []WorkflowHandle[any]{} - for i := range DEFAULT_MAX_RECOVERY_ATTEMPTS * 2 { + for i := range _DEFAULT_MAX_RECOVERY_ATTEMPTS * 2 { recoveredHandles, err := recoverPendingWorkflows(context.Background(), []string{"local"}) if err != nil { t.Fatalf("failed to recover pending workflows on attempt %d: %v", i+1, err) @@ -1096,11 +1096,11 @@ func TestSendRecv(t *testing.T) { } // Get steps for both workflows and verify we have the expected number - sendSteps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) + sendSteps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), handle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for send workflow: %v", err) } - receiveSteps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) + receiveSteps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for receive workflow: %v", err) } From 7b4f8ee6461638d5e1b7614e0e84e8718a6fdd16 Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 22 Jul 2025 15:24:43 -0700 Subject: [PATCH 2/4] do not expose singleton + mandatory userdb/appname --- dbos/admin-server_test.go | 42 ++++------ dbos/dbos.go | 167 ++++++++++++++++---------------------- dbos/logger_test.go | 24 +++--- dbos/system_database.go | 2 +- dbos/utils_test.go | 29 ++++--- dbos/workflows_test.go | 32 +------- 6 files changed, 125 insertions(+), 171 deletions(-) diff --git a/dbos/admin-server_test.go b/dbos/admin-server_test.go index b287aee8..8e582c08 100644 --- a/dbos/admin-server_test.go +++ b/dbos/admin-server_test.go @@ -5,40 +5,33 @@ import ( "encoding/json" "io" "net/http" - "os" "strings" "testing" "time" ) func TestAdminServer(t *testing.T) { - // Skip if database is not available - databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL") - if databaseURL == "" && os.Getenv("PGPASSWORD") == "" { - t.Skip("Database not available (DBOS_SYSTEM_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests") - } + databaseURL := getDatabaseURL(t) - t.Run("Admin server is not started without WithAdminServer option", func(t *testing.T) { + t.Run("Admin server is not started by default", func(t *testing.T) { // Ensure clean state - if dbos != nil { - dbos.Shutdown() - } + Shutdown() - // Launch DBOS without admin server option - executor, err := NewExecutor() + err := Initialize(Config{ + DatabaseURL: databaseURL, + AppName: "test-app", + }) if err != nil { t.Skipf("Failed to create DBOS (database likely not available): %v", err) } - err = executor.Launch() + err = Launch() if err != nil { t.Skipf("Failed to launch DBOS (database likely not available): %v", err) } // Ensure cleanup defer func() { - if executor != nil { - executor.Shutdown() - } + Shutdown() }() // Give time for any startup processes @@ -62,26 +55,25 @@ func TestAdminServer(t *testing.T) { }) t.Run("Admin server endpoints", func(t *testing.T) { - // Ensure clean state - if dbos != nil { - dbos.Shutdown() - } + Shutdown() // Launch DBOS with admin server once for all endpoint tests - executor, err := NewExecutor(WithAdminServer()) + err := Initialize(Config{ + DatabaseURL: databaseURL, + AppName: "test-app", + AdminServer: true, + }) if err != nil { t.Skipf("Failed to create DBOS with admin server (database likely not available): %v", err) } - err = executor.Launch() + err = Launch() if err != nil { t.Skipf("Failed to launch DBOS with admin server (database likely not available): %v", err) } // Ensure cleanup defer func() { - if executor != nil { - executor.Shutdown() - } + Shutdown() }() // Give the server a moment to start diff --git a/dbos/dbos.go b/dbos/dbos.go index 04ac7344..c151bce4 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -64,103 +64,83 @@ func getLogger() *slog.Logger { return logger } -type config struct { - logger *slog.Logger - adminServer bool - databaseURL string - appName string +type Config struct { + DatabaseURL string + AppName string + Logger *slog.Logger + AdminServer bool } -// NewConfig merges configuration from two sources in order of precedence: +// ProcessConfig merges configuration from two sources in order of precedence: // 1. programmatic configuration // 2. environment variables // Finally, it applies default values if needed. -func NewConfig(programmaticConfig config) *config { - dbosConfig := &config{} +func ProcessConfig(inputConfig *Config) (*Config, error) { + // First check required fields + if len(inputConfig.DatabaseURL) == 0 { + return nil, fmt.Errorf("missing required config field: databaseURL") + } + if inputConfig.AppName == "" { + return nil, fmt.Errorf("missing required config field: appName") + } + + dbosConfig := &Config{} // Start with environment variables (lowest precedence) if dbURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL"); dbURL != "" { - dbosConfig.databaseURL = dbURL + dbosConfig.DatabaseURL = dbURL } // Override with programmatic configuration (highest precedence) - if len(programmaticConfig.databaseURL) > 0 { - dbosConfig.databaseURL = programmaticConfig.databaseURL + if len(inputConfig.DatabaseURL) > 0 { + dbosConfig.DatabaseURL = inputConfig.DatabaseURL } - if len(programmaticConfig.appName) > 0 { - dbosConfig.appName = programmaticConfig.appName + if len(inputConfig.AppName) > 0 { + dbosConfig.AppName = inputConfig.AppName } // Copy over parameters that can only be set programmatically - dbosConfig.logger = programmaticConfig.logger - dbosConfig.adminServer = programmaticConfig.adminServer + dbosConfig.Logger = inputConfig.Logger + dbosConfig.AdminServer = inputConfig.AdminServer // Load defaults - if len(dbosConfig.databaseURL) == 0 { + if len(dbosConfig.DatabaseURL) == 0 { getLogger().Info("Using default database URL: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable") password := url.QueryEscape(os.Getenv("PGPASSWORD")) - dbosConfig.databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password) + dbosConfig.DatabaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password) + } + + if dbosConfig.Logger == nil { + dbosConfig.Logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) } - return dbosConfig + + return dbosConfig, nil } -var dbos *Executor // DBOS singleton instance +var dbos *executor // DBOS singleton instance -type Executor struct { +type executor struct { systemDB SystemDatabase queueRunnerCtx context.Context queueRunnerCancelFunc context.CancelFunc queueRunnerDone chan struct{} adminServer *adminServer - config *config -} - -type executorOption func(*config) - -func WithLogger(logger *slog.Logger) executorOption { - return func(config *config) { - config.logger = logger - } -} - -func WithAdminServer() executorOption { - return func(config *config) { - config.adminServer = true - } -} - -func WithDatabaseURL(url string) executorOption { - return func(config *config) { - config.databaseURL = url - } + config *Config } -func WithAppName(name string) executorOption { - return func(config *config) { - config.appName = name - } -} - -func NewExecutor(options ...executorOption) (*Executor, error) { +func Initialize(inputConfig Config) error { if dbos != nil { fmt.Println("warning: DBOS instance already initialized, skipping re-initialization") - return nil, newInitializationError("DBOS already initialized") - } - - // Start with default configuration - config := &config{ - logger: slog.New(slog.NewTextHandler(os.Stderr, nil)), - } - - // Apply options - for _, option := range options { - option(config) + return newInitializationError("DBOS already initialized") } // Load & process the configuration - config = NewConfig(*config) + config, err := ProcessConfig(&inputConfig) + if err != nil { + return newInitializationError(err.Error()) + } // Set global logger - logger = config.logger + logger = config.Logger // Initialize global variables with environment variables, providing defaults if not set _APP_VERSION = os.Getenv("DBOS__APPVERSION") @@ -178,30 +158,30 @@ func NewExecutor(options ...executorOption) (*Executor, error) { _APP_ID = os.Getenv("DBOS__APPID") // Create the system database - systemDB, err := NewSystemDatabase(config.databaseURL) + systemDB, err := NewSystemDatabase(config.DatabaseURL) if err != nil { - return nil, newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) + return newInitializationError(fmt.Sprintf("failed to create system database: %v", err)) } logger.Info("System database initialized") - // Create the executor instance - executor := &Executor{ + // Set the global dbos instance + dbos = &executor{ config: config, systemDB: systemDB, } - // Set the global dbos instance - dbos = executor - - return executor, nil + return nil } -func (e *Executor) Launch() error { +func Launch() error { + if dbos == nil { + return newInitializationError("DBOS instance not initialized, call Initialize first") + } // Start the system database - e.systemDB.Launch(context.Background()) + dbos.systemDB.Launch(context.Background()) // Start the admin server if configured - if e.config.adminServer { + if dbos.config.AdminServer { adminServer := newAdminServer(_DEFAULT_ADMIN_SERVER_PORT) err := adminServer.Start() if err != nil { @@ -209,18 +189,18 @@ func (e *Executor) Launch() error { return newInitializationError(fmt.Sprintf("failed to start admin server: %v", err)) } logger.Info("Admin server started", "port", _DEFAULT_ADMIN_SERVER_PORT) - e.adminServer = adminServer + dbos.adminServer = adminServer } // Create context with cancel function for queue runner ctx, cancel := context.WithCancel(context.Background()) - e.queueRunnerCtx = ctx - e.queueRunnerCancelFunc = cancel - e.queueRunnerDone = make(chan struct{}) + dbos.queueRunnerCtx = ctx + dbos.queueRunnerCancelFunc = cancel + dbos.queueRunnerDone = make(chan struct{}) // Start the queue runner in a goroutine go func() { - defer close(e.queueRunnerDone) + defer close(dbos.queueRunnerDone) queueRunner(ctx) }() logger.Info("Queue runner started") @@ -241,19 +221,19 @@ func (e *Executor) Launch() error { return nil } -func (e *Executor) Shutdown() { - if e == nil { - fmt.Println("Executor instance is nil, cannot shutdown") +func Shutdown() { + if dbos == nil { + fmt.Println("DBOS instance is nil, cannot shutdown") return } // XXX is there a way to ensure all workflows goroutine are done before closing? // Cancel the context to stop the queue runner - if e.queueRunnerCancelFunc != nil { - e.queueRunnerCancelFunc() + if dbos.queueRunnerCancelFunc != nil { + dbos.queueRunnerCancelFunc() // Wait for queue runner to finish - <-e.queueRunnerDone + <-dbos.queueRunnerDone getLogger().Info("Queue runner stopped") } @@ -271,26 +251,23 @@ func (e *Executor) Shutdown() { } } - if e.systemDB != nil { - e.systemDB.Shutdown() - e.systemDB = nil + if dbos.systemDB != nil { + dbos.systemDB.Shutdown() + dbos.systemDB = nil } - if e.adminServer != nil { - err := e.adminServer.Shutdown() + if dbos.adminServer != nil { + err := dbos.adminServer.Shutdown() if err != nil { getLogger().Error("Failed to shutdown admin server", "error", err) } else { getLogger().Info("Admin server shutdown complete") } - e.adminServer = nil + dbos.adminServer = nil } - // Clear global references if this is the global instance - if dbos == e { - if logger != nil { - logger = nil - } - dbos = nil + if logger != nil { + logger = nil } + dbos = nil } diff --git a/dbos/logger_test.go b/dbos/logger_test.go index b3c82df7..2b7eeebb 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -8,20 +8,22 @@ import ( ) func TestLogger(t *testing.T) { + databaseURL := getDatabaseURL(t) t.Run("Default logger", func(t *testing.T) { - executor, err := NewExecutor() // Create executor with default logger + err := Initialize(Config{ + DatabaseURL: databaseURL, + AppName: "test-app", + }) // Create executor with default logger if err != nil { t.Fatalf("Failed to create executor with default logger: %v", err) } - err = executor.Launch() + err = Launch() if err != nil { t.Fatalf("Failed to launch with default logger: %v", err) } t.Cleanup(func() { - if executor != nil { - executor.Shutdown() - } + Shutdown() }) if logger == nil { @@ -43,18 +45,20 @@ func TestLogger(t *testing.T) { // Add some context to the slog logger slogLogger = slogLogger.With("service", "dbos-test", "environment", "test") - executor, err := NewExecutor(WithLogger(slogLogger)) + err := Initialize(Config{ + DatabaseURL: databaseURL, + AppName: "test-app", + Logger: slogLogger, + }) if err != nil { t.Fatalf("Failed to create executor with custom logger: %v", err) } - err = executor.Launch() + err = Launch() if err != nil { t.Fatalf("Failed to launch with custom logger: %v", err) } t.Cleanup(func() { - if executor != nil { - executor.Shutdown() - } + Shutdown() }) if logger == nil { diff --git a/dbos/system_database.go b/dbos/system_database.go index d38ecdcc..a6ea0d6d 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -1029,7 +1029,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro defer tx.Rollback(ctx) // Check if operation was already executed and do nothing if so - checkInput := CheckOperationExecutionDBInput{ + checkInput := checkOperationExecutionDBInput{ workflowID: workflowState.WorkflowID, operationID: stepID, functionName: functionName, diff --git a/dbos/utils_test.go b/dbos/utils_test.go index 9a0c1293..f55596fb 100644 --- a/dbos/utils_test.go +++ b/dbos/utils_test.go @@ -12,15 +12,23 @@ import ( "github.com/jackc/pgx/v5" ) -/* Test database setup */ -func setupDBOS(t *testing.T) { - t.Helper() - +func getDatabaseURL(t *testing.T) string { databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL") if databaseURL == "" { + if os.Getenv("PGPASSWORD") == "" { + t.Skip("PGPASSWORD not set, cannot construct database URL") + } password := url.QueryEscape(os.Getenv("PGPASSWORD")) databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password) } + return databaseURL +} + +/* Test database setup */ +func setupDBOS(t *testing.T) { + t.Helper() + + databaseURL := getDatabaseURL(t) // Clean up the test database parsedURL, err := pgx.ParseConfig(databaseURL) @@ -46,12 +54,15 @@ func setupDBOS(t *testing.T) { t.Fatalf("failed to drop test database: %v", err) } - executor, err := NewExecutor() + err = Initialize(Config{ + DatabaseURL: databaseURL, + AppName: "test-app", + }) if err != nil { t.Fatalf("failed to create DBOS instance: %v", err) } - err = executor.Launch() + err = Launch() if err != nil { t.Fatalf("failed to launch DBOS instance: %v", err) } @@ -63,9 +74,7 @@ func setupDBOS(t *testing.T) { // Register cleanup to run after test completes t.Cleanup(func() { fmt.Println("Cleaning up DBOS instance...") - if executor != nil { - executor.Shutdown() - } + Shutdown() }) } @@ -122,7 +131,7 @@ func restartQueueRunner() { dbos.queueRunnerCtx = ctx dbos.queueRunnerCancelFunc = cancel dbos.queueRunnerDone = make(chan struct{}) - + // Start the queue runner in a goroutine go func() { defer close(dbos.queueRunnerDone) diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index ce02a45a..00c025ab 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -12,9 +12,7 @@ Test workflow and steps features import ( "context" - "encoding/hex" "fmt" - "maps" "strings" "testing" "time" @@ -115,32 +113,6 @@ var ( }) ) -func TestAppVersion(t *testing.T) { - if _, err := hex.DecodeString(_APP_VERSION); err != nil { - t.Fatalf("APP_VERSION is not a valid hex string: %v", err) - } - - // Save the original registry content - originalRegistry := make(map[string]workflowRegistryEntry) - maps.Copy(originalRegistry, registry) - - // Restore the registry after the test - defer func() { - registry = originalRegistry - }() - - // Replace the registry and verify the hash is different - registry = make(map[string]workflowRegistryEntry) - - WithWorkflow(func(ctx context.Context, input string) (string, error) { - return "new-registry-workflow-" + input, nil - }) - hash2 := computeApplicationVersion() - if _APP_VERSION == hash2 { - t.Fatalf("APP_VERSION hash did not change after replacing registry") - } -} - func TestWorkflowsWrapping(t *testing.T) { setupDBOS(t) @@ -1311,14 +1283,14 @@ func TestSendRecv(t *testing.T) { if len(recoveredHandles) != 2 { t.Fatalf("expected 2 recovered handles, got %d", len(recoveredHandles)) } - steps, err := getExecutor().systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) + steps, err := dbos.systemDB.GetWorkflowSteps(context.Background(), sendHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for send idempotency workflow: %v", err) } if len(steps) != 1 { t.Fatalf("expected 1 step in send idempotency workflow, got %d", len(steps)) } - steps, err = getExecutor().systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) + steps, err = dbos.systemDB.GetWorkflowSteps(context.Background(), receiveHandle.GetWorkflowID()) if err != nil { t.Fatalf("failed to get steps for receive idempotency workflow: %v", err) } From 10cea29dcea79b522538d5e10e17563f59053491 Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 22 Jul 2025 15:30:12 -0700 Subject: [PATCH 3/4] rename --- dbos/{admin-server.go => admin_server.go} | 0 dbos/{admin-server_test.go => admin_server_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename dbos/{admin-server.go => admin_server.go} (100%) rename dbos/{admin-server_test.go => admin_server_test.go} (100%) diff --git a/dbos/admin-server.go b/dbos/admin_server.go similarity index 100% rename from dbos/admin-server.go rename to dbos/admin_server.go diff --git a/dbos/admin-server_test.go b/dbos/admin_server_test.go similarity index 100% rename from dbos/admin-server_test.go rename to dbos/admin_server_test.go From 0b6709a988eedc74fce5496aaf92b0d38dea35a6 Mon Sep 17 00:00:00 2001 From: maxdml Date: Tue, 22 Jul 2025 15:31:57 -0700 Subject: [PATCH 4/4] nits --- dbos/admin_server_test.go | 8 ++++---- dbos/dbos.go | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 8e582c08..9f971ffa 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -22,11 +22,11 @@ func TestAdminServer(t *testing.T) { AppName: "test-app", }) if err != nil { - t.Skipf("Failed to create DBOS (database likely not available): %v", err) + t.Skipf("Failed to initialize DBOS: %v", err) } err = Launch() if err != nil { - t.Skipf("Failed to launch DBOS (database likely not available): %v", err) + t.Skipf("Failed to initialize DBOS: %v", err) } // Ensure cleanup @@ -64,11 +64,11 @@ func TestAdminServer(t *testing.T) { AdminServer: true, }) if err != nil { - t.Skipf("Failed to create DBOS with admin server (database likely not available): %v", err) + t.Skipf("Failed to initialize DBOS with admin server: %v", err) } err = Launch() if err != nil { - t.Skipf("Failed to launch DBOS with admin server (database likely not available): %v", err) + t.Skipf("Failed to initialize DBOS with admin server: %v", err) } // Ensure cleanup diff --git a/dbos/dbos.go b/dbos/dbos.go index c151bce4..f56fd54d 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -80,7 +80,7 @@ func ProcessConfig(inputConfig *Config) (*Config, error) { if len(inputConfig.DatabaseURL) == 0 { return nil, fmt.Errorf("missing required config field: databaseURL") } - if inputConfig.AppName == "" { + if len(inputConfig.AppName) == 0 { return nil, fmt.Errorf("missing required config field: appName") } @@ -103,16 +103,15 @@ func ProcessConfig(inputConfig *Config) (*Config, error) { dbosConfig.AdminServer = inputConfig.AdminServer // Load defaults + if dbosConfig.Logger == nil { + dbosConfig.Logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) + } if len(dbosConfig.DatabaseURL) == 0 { - getLogger().Info("Using default database URL: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable") + dbosConfig.Logger.Info("Using default database URL: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable") password := url.QueryEscape(os.Getenv("PGPASSWORD")) dbosConfig.DatabaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password) } - if dbosConfig.Logger == nil { - dbosConfig.Logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) - } - return dbosConfig, nil }