diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml new file mode 100644 index 00000000..60578a2c --- /dev/null +++ b/.github/workflows/security.yml @@ -0,0 +1,51 @@ +name: Security Checks + +on: + push: + branches: + - main + pull_request: + types: + - ready_for_review + - opened + - reopened + - synchronize + workflow_dispatch: + +jobs: + security: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.25.x' + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/go/pkg/mod + ~/.cache/go-build + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Install security tools + run: | + go install golang.org/x/vuln/cmd/govulncheck@latest + go install github.com/securego/gosec/v2/cmd/gosec@latest + + - name: Run govulncheck + run: govulncheck ./... + + - name: Run gosec + run: gosec ./... \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 164d25a4..9c312c5a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,7 +42,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: '1.23.x' + go-version: '1.25.x' - name: Cache Go modules uses: actions/cache@v4 @@ -61,7 +61,7 @@ jobs: run: go install gotest.tools/gotestsum@latest - name: Run tests - run: gotestsum --format github-action -- -race ./... + run: go vet ./... && gotestsum --format github-action -- -race ./... working-directory: ./dbos env: PGPASSWORD: a!b@c$d()e*_,/:;=?@ff[]22 diff --git a/README.md b/README.md index 3d030d77..b128bda2 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,12 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/dbos-inc/dbos-transact-go.svg)](https://pkg.go.dev/github.com/dbos-inc/dbos-transact-go) [![Go Report Card](https://goreportcard.com/badge/github.com/dbos-inc/dbos-transact-go)](https://goreportcard.com/report/github.com/dbos-inc/dbos-transact-go) [![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/dbos-inc/dbos-transact-go?sort=semver)](https://github.com/dbos-inc/dbos-transact-go/releases) +[![Join Discord](https://img.shields.io/badge/Discord-Join%20Chat-5865F2?logo=discord&logoColor=white)](https://discord.com/invite/jsmC6pXGgX) # DBOS Transact: Lightweight Durable Workflow Orchestration with Postgres -#### [Documentation](https://docs.dbos.dev/) &nbsp;&nbsp;•&nbsp;&nbsp; [Examples](https://docs.dbos.dev/examples) &nbsp;&nbsp;•&nbsp;&nbsp; [Github](https://github.com/dbos-inc) &nbsp;&nbsp;•&nbsp;&nbsp; [Discord](https://discord.com/invite/jsmC6pXGgX) +#### [Documentation](https://docs.dbos.dev/) &nbsp;&nbsp;•&nbsp;&nbsp; [Examples](https://docs.dbos.dev/examples) &nbsp;&nbsp;•&nbsp;&nbsp; [GitHub](https://github.com/dbos-inc) #### This Golang version of DBOS Transact is in Alpha!
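The Go-side changes start with the admin server below. For context, here is a minimal local probe of the health endpoint those changes touch — a sketch only, assuming an application has been launched with the admin server enabled on its default port 3001; the address, path, and response body match the tests later in this diff, and the snippet itself is not part of the PR:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"time"
)

func main() {
	// Same port and path exercised by admin_server_test.go (/dbos-healthz on :3001).
	client := &http.Client{Timeout: 1 * time.Second}
	resp, err := client.Get("http://localhost:3001/dbos-healthz")
	if err != nil {
		// The admin server is optional; if it is not running the probe simply fails.
		fmt.Println("admin server unreachable:", err)
		return
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	// The handler writes {"status":"healthy"} with a 200 status code.
	fmt.Println(resp.StatusCode, string(body))
}
```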
diff --git a/dbos/admin_server.go b/dbos/admin_server.go index 1c81d41a..5eb97a6a 100644 --- a/dbos/admin_server.go +++ b/dbos/admin_server.go @@ -10,28 +10,37 @@ import ( ) const ( - healthCheckPath = "/dbos-healthz" - workflowRecoveryPath = "/dbos-workflow-recovery" - workflowQueuesMetadataPath = "/dbos-workflow-queues-metadata" + _HEALTHCHECK_PATH = "/dbos-healthz" + _WORKFLOW_RECOVERY_PATH = "/dbos-workflow-recovery" + _WORKFLOW_QUEUES_METADATA_PATH = "/dbos-workflow-queues-metadata" + + _ADMIN_SERVER_READ_HEADER_TIMEOUT = 5 * time.Second + _ADMIN_SERVER_SHUTDOWN_TIMEOUT = 10 * time.Second ) type adminServer struct { server *http.Server logger *slog.Logger + port int } func newAdminServer(ctx *dbosContext, port int) *adminServer { mux := http.NewServeMux() // Health endpoint - mux.HandleFunc(healthCheckPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(_HEALTHCHECK_PATH, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"status":"healthy"}`)) + _, err := w.Write([]byte(`{"status":"healthy"}`)) + if err != nil { + ctx.logger.Error("Error writing health check response", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } }) // Recovery endpoint - mux.HandleFunc(workflowRecoveryPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(_WORKFLOW_RECOVERY_PATH, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return @@ -67,7 +76,7 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer { }) // Queue metadata endpoint - mux.HandleFunc(workflowQueuesMetadataPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(_WORKFLOW_QUEUES_METADATA_PATH, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return @@ -84,18 +93,20 @@ func newAdminServer(ctx *dbosContext, port int) *adminServer { }) server := &http.Server{ - Addr: fmt.Sprintf(":%d", port), - Handler: mux, + Addr: fmt.Sprintf(":%d", port), + Handler: mux, + ReadHeaderTimeout: _ADMIN_SERVER_READ_HEADER_TIMEOUT, } return &adminServer{ server: server, logger: ctx.logger, + port: port, } } func (as *adminServer) Start() error { - as.logger.Info("Starting admin server", "port", 3001) + as.logger.Info("Starting admin server", "port", as.port) go func() { if err := as.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { @@ -109,8 +120,8 @@ func (as *adminServer) Start() error { func (as *adminServer) Shutdown(ctx context.Context) error { as.logger.Info("Shutting down admin server") - // XXX consider moving the grace period to DBOSContext.Shutdown() - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + // Note: consider moving the grace period to DBOSContext.Shutdown() + ctx, cancel := context.WithTimeout(ctx, _ADMIN_SERVER_SHUTDOWN_TIMEOUT) defer cancel() if err := as.server.Shutdown(ctx); err != nil { diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 2ec425d9..8d6fe2ff 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -38,7 +38,7 @@ func TestAdminServer(t *testing.T) { // Verify admin server is not running client := &http.Client{Timeout: 1 * time.Second} - _, err = client.Get("http://localhost:3001" + healthCheckPath) + _, err = client.Get("http://localhost:3001" + _HEALTHCHECK_PATH) require.Error(t, 
err, "Expected request to fail when admin server is not started") // Verify the DBOS executor doesn't have an admin server instance @@ -89,13 +89,13 @@ func TestAdminServer(t *testing.T) { { name: "Health endpoint responds correctly", method: "GET", - endpoint: "http://localhost:3001" + healthCheckPath, + endpoint: "http://localhost:3001" + _HEALTHCHECK_PATH, expectedStatus: http.StatusOK, }, { name: "Recovery endpoint responds correctly with valid JSON", method: "POST", - endpoint: "http://localhost:3001" + workflowRecoveryPath, + endpoint: "http://localhost:3001" + _WORKFLOW_RECOVERY_PATH, body: bytes.NewBuffer(mustMarshal([]string{"executor1", "executor2"})), contentType: "application/json", expectedStatus: http.StatusOK, @@ -109,13 +109,13 @@ func TestAdminServer(t *testing.T) { { name: "Recovery endpoint rejects invalid methods", method: "GET", - endpoint: "http://localhost:3001" + workflowRecoveryPath, + endpoint: "http://localhost:3001" + _WORKFLOW_RECOVERY_PATH, expectedStatus: http.StatusMethodNotAllowed, }, { name: "Recovery endpoint rejects invalid JSON", method: "POST", - endpoint: "http://localhost:3001" + workflowRecoveryPath, + endpoint: "http://localhost:3001" + _WORKFLOW_RECOVERY_PATH, body: strings.NewReader(`{"invalid": json}`), contentType: "application/json", expectedStatus: http.StatusBadRequest, @@ -123,7 +123,7 @@ func TestAdminServer(t *testing.T) { { name: "Queue metadata endpoint responds correctly", method: "GET", - endpoint: "http://localhost:3001" + workflowQueuesMetadataPath, + endpoint: "http://localhost:3001" + _WORKFLOW_QUEUES_METADATA_PATH, expectedStatus: http.StatusOK, validateResp: func(t *testing.T, resp *http.Response) { var queueMetadata []WorkflowQueue @@ -149,7 +149,7 @@ func TestAdminServer(t *testing.T) { { name: "Queue metadata endpoint rejects invalid methods", method: "POST", - endpoint: "http://localhost:3001" + workflowQueuesMetadataPath, + endpoint: "http://localhost:3001" + _WORKFLOW_QUEUES_METADATA_PATH, expectedStatus: http.StatusMethodNotAllowed, }, } diff --git a/dbos/client_test.go b/dbos/client_test.go index 733c4937..d6dfb4a2 100644 --- a/dbos/client_test.go +++ b/dbos/client_test.go @@ -473,10 +473,6 @@ func TestCancelResume(t *testing.T) { // Verify the deadline was reset (should be different from original) assert.False(t, resumeStatus.Deadline.Equal(originalDeadline), "expected deadline to be reset after resume, but it remained the same: %v", originalDeadline) - // The new deadline should be after resumeStart + workflowTimeout - expectedDeadline := resumeStart.Add(workflowTimeout - 100*time.Millisecond) // Allow some leeway for processing time - assert.True(t, resumeStatus.Deadline.After(expectedDeadline), "deadline %v is too early (expected around %v)", resumeStatus.Deadline, expectedDeadline) - // Wait for the workflow to complete _, err = resumeHandle.GetResult() require.Error(t, err, "expected timeout error, but got none") @@ -491,6 +487,10 @@ func TestCancelResume(t *testing.T) { finalStatus, err := resumeHandle.GetStatus() require.NoError(t, err, "failed to get final workflow status") + // The new deadline should have been set after resumeStart + workflowTimeout + expectedDeadline := resumeStart.Add(workflowTimeout - 100*time.Millisecond) // Allow some leeway for processing time + assert.True(t, finalStatus.Deadline.After(expectedDeadline), "deadline %v is too early (expected around %v)", resumeStatus.Deadline, expectedDeadline) + assert.Equal(t, WorkflowStatusCancelled, finalStatus.Status, "expected final workflow status 
to be CANCELLED") require.True(t, queueEntriesAreCleanedUp(serverCtx), "expected queue entries to be cleaned up after cancel/resume timeout test") diff --git a/dbos/dbos.go b/dbos/dbos.go index 229480ed..a7838177 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -15,6 +15,7 @@ import ( "io" "log/slog" "os" + "path/filepath" "sync" "sync/atomic" "time" @@ -119,10 +120,9 @@ type dbosContext struct { // Wait group for workflow goroutines workflowsWg *sync.WaitGroup - // Workflow registry - workflowRegistry map[string]workflowRegistryEntry - workflowRegMutex *sync.RWMutex - workflowCustomNametoFQN sync.Map // Maps fully qualified workflow names to custom names. Usefor when client enqueues a workflow by name because registry is indexed by FQN. + // Workflow registry - read-mostly sync.Map since registration happens only before launch + workflowRegistry *sync.Map // map[string]workflowRegistryEntry + workflowCustomNametoFQN *sync.Map // Maps fully qualified workflow names to custom names. Usefor when client enqueues a workflow by name because registry is indexed by FQN. // Workflow scheduler workflowScheduler *cron.Cron @@ -158,15 +158,15 @@ func WithValue(ctx DBOSContext, key, val any) DBOSContext { // Will do nothing if the concrete type is not dbosContext if dbosCtx, ok := ctx.(*dbosContext); ok { return &dbosContext{ - ctx: context.WithValue(dbosCtx.ctx, key, val), // Spawn a new child context with the value set - logger: dbosCtx.logger, - systemDB: dbosCtx.systemDB, - workflowsWg: dbosCtx.workflowsWg, - workflowRegistry: dbosCtx.workflowRegistry, - workflowRegMutex: dbosCtx.workflowRegMutex, - applicationVersion: dbosCtx.applicationVersion, - executorID: dbosCtx.executorID, - applicationID: dbosCtx.applicationID, + ctx: context.WithValue(dbosCtx.ctx, key, val), // Spawn a new child context with the value set + logger: dbosCtx.logger, + systemDB: dbosCtx.systemDB, + workflowsWg: dbosCtx.workflowsWg, + workflowRegistry: dbosCtx.workflowRegistry, + workflowCustomNametoFQN: dbosCtx.workflowCustomNametoFQN, + applicationVersion: dbosCtx.applicationVersion, + executorID: dbosCtx.executorID, + applicationID: dbosCtx.applicationID, } } return nil @@ -181,15 +181,15 @@ func WithoutCancel(ctx DBOSContext) DBOSContext { } if dbosCtx, ok := ctx.(*dbosContext); ok { return &dbosContext{ - ctx: context.WithoutCancel(dbosCtx.ctx), - logger: dbosCtx.logger, - systemDB: dbosCtx.systemDB, - workflowsWg: dbosCtx.workflowsWg, - workflowRegistry: dbosCtx.workflowRegistry, - workflowRegMutex: dbosCtx.workflowRegMutex, - applicationVersion: dbosCtx.applicationVersion, - executorID: dbosCtx.executorID, - applicationID: dbosCtx.applicationID, + ctx: context.WithoutCancel(dbosCtx.ctx), + logger: dbosCtx.logger, + systemDB: dbosCtx.systemDB, + workflowsWg: dbosCtx.workflowsWg, + workflowRegistry: dbosCtx.workflowRegistry, + workflowCustomNametoFQN: dbosCtx.workflowCustomNametoFQN, + applicationVersion: dbosCtx.applicationVersion, + executorID: dbosCtx.executorID, + applicationID: dbosCtx.applicationID, } } return nil @@ -205,15 +205,15 @@ func WithTimeout(ctx DBOSContext, timeout time.Duration) (DBOSContext, context.C if dbosCtx, ok := ctx.(*dbosContext); ok { newCtx, cancelFunc := context.WithTimeoutCause(dbosCtx.ctx, timeout, errors.New("DBOS context timeout")) return &dbosContext{ - ctx: newCtx, - logger: dbosCtx.logger, - systemDB: dbosCtx.systemDB, - workflowsWg: dbosCtx.workflowsWg, - workflowRegistry: dbosCtx.workflowRegistry, - workflowRegMutex: dbosCtx.workflowRegMutex, - applicationVersion: 
dbosCtx.applicationVersion, - executorID: dbosCtx.executorID, - applicationID: dbosCtx.applicationID, + ctx: newCtx, + logger: dbosCtx.logger, + systemDB: dbosCtx.systemDB, + workflowsWg: dbosCtx.workflowsWg, + workflowRegistry: dbosCtx.workflowRegistry, + workflowCustomNametoFQN: dbosCtx.workflowCustomNametoFQN, + applicationVersion: dbosCtx.applicationVersion, + executorID: dbosCtx.executorID, + applicationID: dbosCtx.applicationID, }, cancelFunc } return nil, func() {} @@ -261,11 +261,11 @@ func (c *dbosContext) GetApplicationID() string { func NewDBOSContext(inputConfig Config) (DBOSContext, error) { ctx, cancelFunc := context.WithCancelCause(context.Background()) initExecutor := &dbosContext{ - workflowsWg: &sync.WaitGroup{}, - ctx: ctx, - ctxCancelFunc: cancelFunc, - workflowRegistry: make(map[string]workflowRegistryEntry), - workflowRegMutex: &sync.RWMutex{}, + workflowsWg: &sync.WaitGroup{}, + ctx: ctx, + ctxCancelFunc: cancelFunc, + workflowRegistry: &sync.Map{}, + workflowCustomNametoFQN: &sync.Map{}, } // Load and process the configuration @@ -438,7 +438,20 @@ func getBinaryHash() (string, error) { return "", err } - file, err := os.Open(execPath) + execPath, err = filepath.EvalSymlinks(execPath) + if err != nil { + return "", fmt.Errorf("resolve self path: %w", err) + } + + fi, err := os.Lstat(execPath) + if err != nil { + return "", err + } + if !fi.Mode().IsRegular() { + return "", fmt.Errorf("executable is not a regular file") + } + + file, err := os.Open(execPath) // #nosec G304 -- opening our own executable, not user-supplied if err != nil { return "", err } diff --git a/dbos/errors.go b/dbos/errors.go index b71c3780..1fb9b931 100644 --- a/dbos/errors.go +++ b/dbos/errors.go @@ -8,7 +8,6 @@ type DBOSErrorCode int const ( ConflictingIDError DBOSErrorCode = iota + 1 // Workflow ID conflicts or duplicate operations InitializationError // DBOS context initialization failures - WorkflowFunctionNotFound // Workflow function not registered NonExistentWorkflowError // Referenced workflow does not exist ConflictingWorkflowError // Workflow with same ID already exists with different parameters WorkflowCancelled // Workflow was cancelled during execution @@ -68,18 +67,6 @@ func newInitializationError(message string) *DBOSError { } } -func newWorkflowFunctionNotFoundError(workflowID, message string) *DBOSError { - msg := fmt.Sprintf("Workflow function not found for workflow ID %s", workflowID) - if message != "" { - msg += ": " + message - } - return &DBOSError{ - Message: msg, - Code: WorkflowFunctionNotFound, - WorkflowID: workflowID, - } -} - func newNonExistentWorkflowError(workflowID string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("workflow %s does not exist", workflowID), diff --git a/dbos/queue.go b/dbos/queue.go index e56f123b..709dd503 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -145,7 +145,7 @@ type queueRunner struct { workflowQueueRegistry map[string]WorkflowQueue // Channel to signal completion back to the DBOS context - completionChan chan bool + completionChan chan struct{} } func newQueueRunner() *queueRunner { @@ -158,7 +158,7 @@ func newQueueRunner() *queueRunner { jitterMin: 0.95, jitterMax: 1.05, workflowQueueRegistry: make(map[string]WorkflowQueue), - completionChan: make(chan bool), + completionChan: make(chan struct{}), } } @@ -179,7 +179,11 @@ func (qr *queueRunner) run(ctx *dbosContext) { // Iterate through all queues in the registry for queueName, queue := range qr.workflowQueueRegistry { // Call DequeueWorkflows for each queue - 
dequeuedWorkflows, err := ctx.systemDB.dequeueWorkflows(ctx, queue, ctx.executorID, ctx.applicationVersion) + dequeuedWorkflows, err := ctx.systemDB.dequeueWorkflows(ctx, dequeueWorkflowsInput{ + queue: queue, + executorID: ctx.executorID, + applicationVersion: ctx.applicationVersion, + }) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { @@ -206,11 +210,16 @@ func (qr *queueRunner) run(ctx *dbosContext) { continue } - registeredWorkflow, exists := ctx.workflowRegistry[wfName.(string)] + registeredWorkflowAny, exists := ctx.workflowRegistry.Load(wfName.(string)) if !exists { ctx.logger.Error("workflow function not found in registry", "workflow_name", workflow.name) continue } + registeredWorkflow, ok := registeredWorkflowAny.(workflowRegistryEntry) + if !ok { + ctx.logger.Error("invalid workflow registry entry type", "workflow_name", workflow.name) + continue + } // Deserialize input var input any @@ -245,14 +254,14 @@ func (qr *queueRunner) run(ctx *dbosContext) { } // Apply jitter to the polling interval - jitter := qr.jitterMin + rand.Float64()*(qr.jitterMax-qr.jitterMin) + jitter := qr.jitterMin + rand.Float64()*(qr.jitterMax-qr.jitterMin) // #nosec G404 -- non-crypto jitter; acceptable sleepDuration := time.Duration(pollingInterval * jitter * float64(time.Second)) // Sleep with jittered interval, but allow early exit on context cancellation select { case <-ctx.Done(): ctx.logger.Info("Queue runner stopping due to context cancellation", "cause", context.Cause(ctx)) - qr.completionChan <- true + qr.completionChan <- struct{}{} return case <-time.After(sleepDuration): // Continue to next iteration diff --git a/dbos/queues_test.go b/dbos/queues_test.go index f49fd82e..1fccffc9 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -439,7 +439,7 @@ func TestQueueRecovery(t *testing.T) { require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after global concurrency test") } -// TODO: we can update this test to have the same logic than TestWorkerConcurrency +// Note: we could update this test to have the same logic than TestWorkerConcurrency func TestGlobalConcurrency(t *testing.T) { dbosCtx := setupDBOS(t, true, true) @@ -850,6 +850,18 @@ func TestQueueTimeouts(t *testing.T) { RegisterWorkflow(dbosCtx, detachedWorkflow) RegisterWorkflow(dbosCtx, enqueuedWorkflowEnqueuesADetachedWorkflow) + timeoutOnDequeueQueue := NewWorkflowQueue(dbosCtx, "timeout-on-dequeue-queue", WithGlobalConcurrency(1)) + blockingEvent := NewEvent() + blockingWorkflow := func(ctx DBOSContext, _ string) (string, error) { + blockingEvent.Wait() + return "blocking-done", nil + } + RegisterWorkflow(dbosCtx, blockingWorkflow) + fastWorkflow := func(ctx DBOSContext, _ string) (string, error) { + return "done", nil + } + RegisterWorkflow(dbosCtx, fastWorkflow) + dbosCtx.Launch() t.Run("EnqueueWorkflowTimeout", func(t *testing.T) { @@ -936,6 +948,46 @@ func TestQueueTimeouts(t *testing.T) { require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after workflow cancellation, but they are not") }) + + t.Run("TimeoutOnlySetOnDequeue", func(t *testing.T) { + // Test that deadline is only set when workflow is dequeued, not when enqueued + + // Enqueue blocking workflow first + blockingHandle, err := RunAsWorkflow(dbosCtx, blockingWorkflow, "blocking", WithQueue(timeoutOnDequeueQueue.Name)) + require.NoError(t, err, "failed to enqueue blocking workflow") + + // Set a timeout that would expire if set on enqueue + timeout := 
2 * time.Second + timeoutCtx, cancelFunc := WithTimeout(dbosCtx, timeout) + defer cancelFunc() + + // Enqueue second workflow with timeout + handle, err := RunAsWorkflow(timeoutCtx, fastWorkflow, "timeout-test", WithQueue(timeoutOnDequeueQueue.Name)) + require.NoError(t, err, "failed to enqueue timeout workflow") + + // Sleep for duration exceeding the timeout + time.Sleep(timeout * 2) + + // Signal the blocking workflow to complete + blockingEvent.Set() + + // Wait for blocking workflow to complete + blockingResult, err := blockingHandle.GetResult() + require.NoError(t, err, "failed to get result from blocking workflow") + assert.Equal(t, "blocking-done", blockingResult, "expected blocking workflow result") + + // Now the second workflow should dequeue and complete successfully (timeout should be much longer than execution time) + // Note: this might be flaky if the dequeue is delayed too long + _, err = handle.GetResult() + require.NoError(t, err, "unexpected error from workflow") + + // Check the workflow status: should be success + finalStatus, err := handle.GetStatus() + require.NoError(t, err, "failed to get final status of timeout workflow") + assert.Equal(t, WorkflowStatusSuccess, finalStatus.Status, "expected timeout workflow status to be WorkflowStatusSuccess") + + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after test") + }) } func TestPriorityQueue(t *testing.T) { diff --git a/dbos/recovery.go b/dbos/recovery.go index 47c3f0f7..3c75b2de 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -32,7 +32,7 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow continue } if cleared { - workflowHandles = append(workflowHandles, &workflowPollingHandle[any]{workflowID: workflow.ID, dbosContext: ctx}) + workflowHandles = append(workflowHandles, newWorkflowPollingHandle[any](ctx, workflow.ID)) } continue } @@ -43,11 +43,16 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow continue } - registeredWorkflow, exists := ctx.workflowRegistry[wfName.(string)] + registeredWorkflowAny, exists := ctx.workflowRegistry.Load(wfName.(string)) if !exists { ctx.logger.Error("Workflow function not found in registry", "workflow_id", workflow.ID, "name", workflow.Name) continue } + registeredWorkflow, ok := registeredWorkflowAny.(workflowRegistryEntry) + if !ok { + ctx.logger.Error("invalid workflow registry entry type", "workflow_id", workflow.ID, "name", workflow.Name) + continue + } // Convert workflow parameters to options opts := []WorkflowOption{ diff --git a/dbos/system_database.go b/dbos/system_database.go index 01e59ce8..c72e2a72 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -58,14 +58,14 @@ type systemDatabase interface { sleep(ctx context.Context, duration time.Duration) (time.Duration, error) // Queues - dequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) + dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInput) ([]dequeuedWorkflow, error) clearQueueAssignment(ctx context.Context, workflowID string) (bool, error) } type sysDB struct { pool *pgxpool.Pool notificationListenerConnection *pgconn.PgConn - notificationLoopDone chan bool + notificationLoopDone chan struct{} notificationsMap *sync.Map logger *slog.Logger launched bool @@ -104,7 +104,6 @@ func createDatabaseIfNotExists(ctx context.Context, databaseURL string, logger * return newInitializationError(fmt.Sprintf("failed 
to check if database exists: %v", err)) } if !exists { - // TODO: validate db name createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{dbName}.Sanitize()) _, err = conn.Exec(ctx, createSQL) if err != nil { @@ -119,7 +118,21 @@ func createDatabaseIfNotExists(ctx context.Context, databaseURL string, logger * //go:embed migrations/*.sql var migrationFiles embed.FS -const _DBOS_MIGRATION_TABLE = "dbos_schema_migrations" +const ( + _DBOS_MIGRATION_TABLE = "dbos_schema_migrations" + + // PostgreSQL error codes + _PG_ERROR_UNIQUE_VIOLATION = "23505" + _PG_ERROR_FOREIGN_KEY_VIOLATION = "23503" + + // Notification channels + _DBOS_NOTIFICATIONS_CHANNEL = "dbos_notifications_channel" + _DBOS_WORKFLOW_EVENTS_CHANNEL = "dbos_workflow_events_channel" + + // Database retry timeouts + _DB_CONNECTION_RETRY_DELAY = 500 * time.Millisecond + _DB_RETRY_INTERVAL = 1 * time.Second +) func runMigrations(databaseURL string) error { // Change the driver to pgx5 @@ -194,7 +207,7 @@ func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Log return nil, fmt.Errorf("failed to parse database URL: %v", err) } config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - if n.Channel == "dbos_notifications_channel" || n.Channel == "dbos_workflow_events_channel" { + if n.Channel == _DBOS_NOTIFICATIONS_CHANNEL || n.Channel == _DBOS_WORKFLOW_EVENTS_CHANNEL { // Check if an entry exists in the map, indexed by the payload // If yes, broadcast on the condition variable so listeners can wake up if cond, exists := notificationsMap.Load(n.Payload); exists { @@ -211,7 +224,7 @@ func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Log pool: pool, notificationListenerConnection: notificationListenerConnection, notificationsMap: notificationsMap, - notificationLoopDone: make(chan bool), + notificationLoopDone: make(chan struct{}), logger: logger, }, nil } @@ -230,7 +243,10 @@ func (s *sysDB) shutdown(ctx context.Context) { // Context wasn't cancelled, let's manually close if !errors.Is(ctx.Err(), context.Canceled) { - s.notificationListenerConnection.Close(ctx) + err := s.notificationListenerConnection.Close(ctx) + if err != nil { + s.logger.Error("Failed to close notification listener connection", "error", err) + } } if s.launched { @@ -242,7 +258,7 @@ func (s *sysDB) shutdown(ctx context.Context) { // Allow pgx health checks to complete // https://github.com/jackc/pgx/blob/15bca4a4e14e0049777c1245dba4c16300fe4fd0/pgxpool/pool.go#L417 // These trigger go-leak alerts - time.Sleep(500 * time.Millisecond) + time.Sleep(_DB_CONNECTION_RETRY_DELAY) s.launched = false } @@ -290,7 +306,7 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt var timeoutMs *int64 = nil if input.status.Timeout > 0 { - millis := input.status.Timeout.Milliseconds() + millis := input.status.Timeout.Round(time.Millisecond).Milliseconds() timeoutMs = &millis } @@ -377,7 +393,7 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt ) if err != nil { // Handle unique constraint violation for the deduplication ID (this should be the only case for a 23505) - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_UNIQUE_VIOLATION { return nil, newQueueDeduplicatedError( input.status.ID, input.status.QueueName, @@ -770,25 +786,17 @@ func (s *sysDB) resumeWorkflow(ctx context.Context, workflowID string) error { return nil // Workflow is complete, do nothing } - 
// If the original workflow has a timeout, let's recompute a deadline - var deadline *int64 = nil - if wf.Timeout > 0 { - deadlineMs := time.Now().Add(wf.Timeout).UnixMilli() - deadline = &deadlineMs - } - // Set the workflow's status to ENQUEUED and clear its recovery attempts, set new deadline updateStatusQuery := `UPDATE dbos.workflow_status SET status = $1, queue_name = $2, recovery_attempts = $3, - workflow_deadline_epoch_ms = $4, deduplication_id = NULL, - started_at_epoch_ms = NULL, updated_at = $5 - WHERE workflow_uuid = $6` + workflow_deadline_epoch_ms = NULL, deduplication_id = NULL, + started_at_epoch_ms = NULL, updated_at = $4 + WHERE workflow_uuid = $5` _, err = tx.Exec(ctx, updateStatusQuery, WorkflowStatusEnqueued, _DBOS_INTERNAL_QUEUE_NAME, 0, - deadline, time.Now().UnixMilli(), workflowID) if err != nil { @@ -921,7 +929,7 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (any err := row.Scan(&status, &outputString, &errorStr) if err != nil { if err == pgx.ErrNoRows { - time.Sleep(1 * time.Second) + time.Sleep(_DB_RETRY_INTERVAL) continue } return nil, fmt.Errorf("failed to query workflow status: %w", err) @@ -942,7 +950,7 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (any case WorkflowStatusCancelled: return output, newAwaitedWorkflowCancelledError(workflowID) default: - time.Sleep(1 * time.Second) + time.Sleep(_DB_RETRY_INTERVAL) } } } @@ -1000,7 +1008,7 @@ func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperation if err != nil { s.logger.Error("RecordOperationResult Error occurred", "error", err) - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_UNIQUE_VIOLATION { return newWorkflowConflictIDError(input.workflowID) } return err @@ -1051,7 +1059,7 @@ func (s *sysDB) recordChildWorkflow(ctx context.Context, input recordChildWorkfl if err != nil { // Check for unique constraint violation (conflict ID error) - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_UNIQUE_VIOLATION { return fmt.Errorf( "child workflow %s already registered for parent workflow %s (operation ID: %d)", input.childWorkflowID, input.parentWorkflowID, input.stepID) @@ -1354,17 +1362,21 @@ func (s *sysDB) sleep(ctx context.Context, duration time.Duration) (time.Duratio func (s *sysDB) notificationListenerLoop(ctx context.Context) { defer func() { - s.notificationLoopDone <- true + s.notificationLoopDone <- struct{}{} }() s.logger.Info("DBOS: Starting notification listener loop") - mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel; LISTEN dbos_workflow_events_channel") + mrr := s.notificationListenerConnection.Exec(ctx, fmt.Sprintf("LISTEN %s; LISTEN %s", _DBOS_NOTIFICATIONS_CHANNEL, _DBOS_WORKFLOW_EVENTS_CHANNEL)) results, err := mrr.ReadAll() if err != nil { s.logger.Error("Failed to listen on notification channels", "error", err) return } - mrr.Close() + err = mrr.Close() + if err != nil { + s.logger.Error("Failed to close connection after setting notification listeners", "error", err) + return + } for _, result := range results { if result.Err != nil { @@ -1390,9 +1402,10 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) { return } - // Other errors - log and retry. XXX eventually add exponential backoff + jitter + // Other errors - log and retry. 
+ // TODO add exponential backoff + jitter s.logger.Error("Error waiting for notification", "error", err) - time.Sleep(500 * time.Millisecond) + time.Sleep(_DB_CONNECTION_RETRY_DELAY) continue } } @@ -1465,7 +1478,7 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error { _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString) if err != nil { // Check for foreign key violation (destination workflow doesn't exist) - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23503" { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_FOREIGN_KEY_VIOLATION { return newNonExistentWorkflowError(input.DestinationID) } return fmt.Errorf("failed to insert notification: %w", err) @@ -1501,7 +1514,6 @@ func (s *sysDB) recv(ctx context.Context, input WorkflowRecvInput) (any, error) functionName := "DBOS.recv" // Get workflow state from context - // XXX these checks might be better suited for outside of the system db code. We'll see when we implement the client. wfState, ok := ctx.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { return nil, newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") @@ -1521,7 +1533,6 @@ func (s *sysDB) recv(ctx context.Context, input WorkflowRecvInput) (any, error) } // Check if operation was already executed - // XXX this might not need to be in the transaction checkInput := checkOperationExecutionDBInput{ workflowID: destinationID, stepID: stepID, @@ -1548,7 +1559,6 @@ func (s *sysDB) recv(ctx context.Context, input WorkflowRecvInput) (any, error) } defer func() { // Clean up the condition variable after we're done and broadcast to wake up any waiting goroutines - // XXX We should handle panics in this function and make sure we call this. Not a problem for now as panic will crash the importing package. cond.Broadcast() s.notificationsMap.Delete(payload) }() @@ -1563,7 +1573,6 @@ func (s *sysDB) recv(ctx context.Context, input WorkflowRecvInput) (any, error) } if !exists { // Wait for notifications using condition variable with timeout pattern - // XXX should we prevent zero or negative timeouts? 
s.logger.Debug("Waiting for notification on condition variable", "payload", payload) done := make(chan struct{}) @@ -1783,7 +1792,7 @@ func (s *sysDB) getEvent(ctx context.Context, input WorkflowGetEventInput) (any, return nil, fmt.Errorf("failed to query workflow event: %w", err) } - if err == pgx.ErrNoRows || valueString == nil { // XXX valueString should never be `nil` + if err == pgx.ErrNoRows || valueString == nil { // valueString should never be `nil` // Wait for notification with timeout using condition variable done := make(chan struct{}) go func() { @@ -1852,8 +1861,13 @@ type dequeuedWorkflow struct { input string } -// TODO input struct -func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) { +type dequeueWorkflowsInput struct { + queue WorkflowQueue + executorID string + applicationVersion string +} + +func (s *sysDB) dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInput) ([]dequeuedWorkflow, error) { // Begin transaction with snapshot isolation tx, err := s.pool.Begin(ctx) if err != nil { @@ -1870,8 +1884,8 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu // First check the rate limiter startTimeMs := time.Now().UnixMilli() var numRecentQueries int - if queue.RateLimit != nil { - limiterPeriod := time.Duration(queue.RateLimit.Period * float64(time.Second)) + if input.queue.RateLimit != nil { + limiterPeriod := time.Duration(input.queue.RateLimit.Period * float64(time.Second)) // Calculate the cutoff time: current time minus limiter period cutoffTimeMs := time.Now().Add(-limiterPeriod).UnixMilli() @@ -1885,22 +1899,22 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu AND started_at_epoch_ms > $3` err := tx.QueryRow(ctx, limiterQuery, - queue.Name, + input.queue.Name, WorkflowStatusEnqueued, cutoffTimeMs).Scan(&numRecentQueries) if err != nil { return nil, fmt.Errorf("failed to query rate limiter: %w", err) } - if numRecentQueries >= queue.RateLimit.Limit { + if numRecentQueries >= input.queue.RateLimit.Limit { return []dequeuedWorkflow{}, nil } } // Calculate max_tasks based on concurrency limits - maxTasks := queue.MaxTasksPerIteration + maxTasks := input.queue.MaxTasksPerIteration - if queue.WorkerConcurrency != nil || queue.GlobalConcurrency != nil { + if input.queue.WorkerConcurrency != nil || input.queue.GlobalConcurrency != nil { // Count pending workflows by executor pendingQuery := ` SELECT executor_id, COUNT(*) as task_count @@ -1908,7 +1922,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu WHERE queue_name = $1 AND status = $2 GROUP BY executor_id` - rows, err := tx.Query(ctx, pendingQuery, queue.Name, WorkflowStatusPending) + rows, err := tx.Query(ctx, pendingQuery, input.queue.Name, WorkflowStatusPending) if err != nil { return nil, fmt.Errorf("failed to query pending workflows: %w", err) } @@ -1924,28 +1938,28 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu pendingWorkflowsDict[executorIDRow] = taskCount } - localPendingWorkflows := pendingWorkflowsDict[executorID] + localPendingWorkflows := pendingWorkflowsDict[input.executorID] // Check worker concurrency limit - if queue.WorkerConcurrency != nil { - workerConcurrency := *queue.WorkerConcurrency + if input.queue.WorkerConcurrency != nil { + workerConcurrency := *input.queue.WorkerConcurrency if localPendingWorkflows > workerConcurrency { - s.logger.Warn("Local pending 
workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", workerConcurrency) + s.logger.Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", input.queue.Name, "concurrency_limit", workerConcurrency) } availableWorkerTasks := max(workerConcurrency-localPendingWorkflows, 0) maxTasks = availableWorkerTasks } // Check global concurrency limit - if queue.GlobalConcurrency != nil { + if input.queue.GlobalConcurrency != nil { globalPendingWorkflows := 0 for _, count := range pendingWorkflowsDict { globalPendingWorkflows += count } - concurrency := *queue.GlobalConcurrency + concurrency := *input.queue.GlobalConcurrency if globalPendingWorkflows > concurrency { - s.logger.Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", concurrency) + s.logger.Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", input.queue.Name, "concurrency_limit", concurrency) } availableTasks := max(concurrency-globalPendingWorkflows, 0) if availableTasks < maxTasks { @@ -1957,7 +1971,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu // Build the query to select workflows for dequeueing // Use SKIP LOCKED when no global concurrency is set to avoid blocking, // otherwise use NOWAIT to ensure consistent view across processes - skipLocks := queue.GlobalConcurrency == nil + skipLocks := input.queue.GlobalConcurrency == nil var lockClause string if skipLocks { lockClause = "FOR UPDATE SKIP LOCKED" @@ -1966,7 +1980,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu } var query string - if queue.PriorityEnabled { + if input.queue.PriorityEnabled { query = fmt.Sprintf(` SELECT workflow_uuid FROM dbos.workflow_status @@ -1991,7 +2005,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu } // Execute the query to get workflow IDs - rows, err := tx.Query(ctx, query, queue.Name, WorkflowStatusEnqueued, applicationVersion) + rows, err := tx.Query(ctx, query, input.queue.Name, WorkflowStatusEnqueued, input.applicationVersion) if err != nil { return nil, fmt.Errorf("failed to query enqueued workflows: %w", err) } @@ -2014,15 +2028,15 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu } if len(dequeuedIDs) > 0 { - s.logger.Debug("attempting to dequeue task(s)", "queueName", queue.Name, "numTasks", len(dequeuedIDs)) + s.logger.Debug("attempting to dequeue task(s)", "queueName", input.queue.Name, "numTasks", len(dequeuedIDs)) } // Update workflows to PENDING status and get their details var retWorkflows []dequeuedWorkflow for _, id := range dequeuedIDs { // If we have a limiter, stop dequeueing workflows when the number of workflows started this period exceeds the limit. 
- if queue.RateLimit != nil { - if len(retWorkflows)+numRecentQueries >= queue.RateLimit.Limit { + if input.queue.RateLimit != nil { + if len(retWorkflows)+numRecentQueries >= input.queue.RateLimit.Limit { break } } @@ -2048,8 +2062,8 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu var inputString *string err := tx.QueryRow(ctx, updateQuery, WorkflowStatusPending, - applicationVersion, - executorID, + input.applicationVersion, + input.executorID, startTimeMs, id).Scan(&retWorkflow.name, &inputString) diff --git a/dbos/workflow.go b/dbos/workflow.go index d8de1d2b..b1657541 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -89,11 +89,57 @@ type WorkflowHandle[R any] interface { GetWorkflowID() string // Get the unique workflow identifier } +// baseWorkflowHandle contains common fields and methods for workflow handles +type baseWorkflowHandle struct { + workflowID string + dbosContext DBOSContext +} + +// GetStatus returns the current status of the workflow from the database +func (h *baseWorkflowHandle) GetStatus() (WorkflowStatus, error) { + workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.listWorkflows(h.dbosContext, listWorkflowsDBInput{ + workflowIDs: []string{h.workflowID}, + loadInput: true, + loadOutput: true, + }) + if err != nil { + return WorkflowStatus{}, fmt.Errorf("failed to get workflow status: %w", err) + } + if len(workflowStatuses) == 0 { + return WorkflowStatus{}, newNonExistentWorkflowError(h.workflowID) + } + return workflowStatuses[0], nil +} + +func (h *baseWorkflowHandle) GetWorkflowID() string { + return h.workflowID +} + +// newWorkflowHandle creates a new workflowHandle with the given parameters +func newWorkflowHandle[R any](ctx DBOSContext, workflowID string, outcomeChan chan workflowOutcome[R]) *workflowHandle[R] { + return &workflowHandle[R]{ + baseWorkflowHandle: baseWorkflowHandle{ + workflowID: workflowID, + dbosContext: ctx, + }, + outcomeChan: outcomeChan, + } +} + +// newWorkflowPollingHandle creates a new workflowPollingHandle with the given parameters +func newWorkflowPollingHandle[R any](ctx DBOSContext, workflowID string) *workflowPollingHandle[R] { + return &workflowPollingHandle[R]{ + baseWorkflowHandle: baseWorkflowHandle{ + workflowID: workflowID, + dbosContext: ctx, + }, + } +} + // workflowHandle is a concrete implementation of WorkflowHandle type workflowHandle[R any] struct { - workflowID string + baseWorkflowHandle outcomeChan chan workflowOutcome[R] - dbosContext DBOSContext } // GetResult waits for the workflow to complete and returns the result @@ -127,29 +173,8 @@ func (h *workflowHandle[R]) GetResult() (R, error) { return outcome.result, outcome.err } -// GetStatus returns the current status of the workflow from the database -func (h *workflowHandle[R]) GetStatus() (WorkflowStatus, error) { - workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.listWorkflows(h.dbosContext, listWorkflowsDBInput{ - workflowIDs: []string{h.workflowID}, - loadInput: true, - loadOutput: true, - }) - if err != nil { - return WorkflowStatus{}, fmt.Errorf("failed to get workflow status: %w", err) - } - if len(workflowStatuses) == 0 { - return WorkflowStatus{}, newNonExistentWorkflowError(h.workflowID) - } - return workflowStatuses[0], nil -} - -func (h *workflowHandle[R]) GetWorkflowID() string { - return h.workflowID -} - type workflowPollingHandle[R any] struct { - workflowID string - dbosContext DBOSContext + baseWorkflowHandle } func (h *workflowPollingHandle[R]) GetResult() (R, error) { @@ 
-157,7 +182,6 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) { if result != nil { typedResult, ok := result.(R) if !ok { - // TODO check what this looks like in practice return *new(R), newWorkflowUnexpectedResultType(h.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", result)) } // If we are calling GetResult inside a workflow, record the result as a step result @@ -177,8 +201,8 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) { } recordResultErr := h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput) if recordResultErr != nil { - // XXX do we want to fail this? h.dbosContext.(*dbosContext).logger.Error("failed to record get result", "error", recordResultErr) + return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) } } return typedResult, err @@ -186,26 +210,6 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) { return *new(R), err } -// GetStatus returns the current status of the workflow from the database -func (h *workflowPollingHandle[R]) GetStatus() (WorkflowStatus, error) { - workflowStatuses, err := h.dbosContext.(*dbosContext).systemDB.listWorkflows(h.dbosContext, listWorkflowsDBInput{ - workflowIDs: []string{h.workflowID}, - loadInput: true, - loadOutput: true, - }) - if err != nil { - return WorkflowStatus{}, fmt.Errorf("failed to get workflow status: %w", err) - } - if len(workflowStatuses) == 0 { - return WorkflowStatus{}, newNonExistentWorkflowError(h.workflowID) - } - return workflowStatuses[0], nil -} - -func (h *workflowPollingHandle[R]) GetWorkflowID() string { - return h.workflowID -} - /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ @@ -230,21 +234,18 @@ func registerWorkflow(ctx DBOSContext, workflowFQN string, fn WrappedWorkflowFun panic("Cannot register workflow after DBOS has launched") } - c.workflowRegMutex.Lock() - defer c.workflowRegMutex.Unlock() - - if _, exists := c.workflowRegistry[workflowFQN]; exists { - c.logger.Error("workflow function already registered", "fqn", workflowFQN) - panic(newConflictingRegistrationError(workflowFQN)) - } - - // We must keep the registry indexed by FQN (because RunAsWorkflow uses reflection to find the function name and uses that to look it up in the registry) - c.workflowRegistry[workflowFQN] = workflowRegistryEntry{ + // Check if workflow already exists and store atomically using LoadOrStore + entry := workflowRegistryEntry{ wrappedFunction: fn, maxRetries: maxRetries, name: customName, } + if _, exists := c.workflowRegistry.LoadOrStore(workflowFQN, entry); exists { + c.logger.Error("workflow function already registered", "fqn", workflowFQN) + panic(newConflictingRegistrationError(workflowFQN)) + } + // We need to get a mapping from custom name to FQN for registry lookups that might not know the FQN (queue, recovery) if len(customName) > 0 { c.workflowCustomNametoFQN.Store(customName, workflowFQN) @@ -278,13 +279,16 @@ func registerScheduledWorkflow(ctx DBOSContext, workflowName string, fn Workflow // Use Next if Prev is not set, which will only happen for the first run scheduledTime = entry.Next } - wfID := fmt.Sprintf("sched-%s-%s", workflowName, scheduledTime) // XXX we can rethink the format + wfID := fmt.Sprintf("sched-%s-%s", workflowName, scheduledTime) opts := []WorkflowOption{ WithWorkflowID(wfID), WithQueue(_DBOS_INTERNAL_QUEUE_NAME), withWorkflowName(workflowName), } - 
ctx.RunAsWorkflow(ctx, fn, scheduledTime, opts...) + _, err := ctx.RunAsWorkflow(ctx, fn, scheduledTime, opts...) + if err != nil { + c.logger.Error("failed to run scheduled workflow", "fqn", workflowName, "error", err) + } }) if err != nil { panic(fmt.Sprintf("failed to register scheduled workflow: %v", err)) @@ -302,6 +306,11 @@ type workflowRegistrationOption func(*workflowRegistrationParams) const ( _DEFAULT_MAX_RECOVERY_ATTEMPTS = 100 + + // Step retry defaults + _DEFAULT_STEP_BASE_INTERVAL = 100 * time.Millisecond + _DEFAULT_STEP_MAX_INTERVAL = 5 * time.Second + _DEFAULT_STEP_BACKOFF_FACTOR = 2.0 ) // WithMaxRetries sets the maximum number of retry attempts for workflow recovery. @@ -380,28 +389,32 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R // Register a type-erased version of the durable workflow for recovery typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { - // This type check is redundant with the one in the wrapper, but I'd better be safe than sorry typedInput, ok := input.(P) if !ok { - // FIXME: we need to record the error in the database here + wfID, err := ctx.GetWorkflowID() + if err != nil { + return nil, fmt.Errorf("failed to get workflow ID: %w", err) + } + err = ctx.(*dbosContext).systemDB.updateWorkflowOutcome(WithoutCancel(ctx), updateWorkflowOutcomeDBInput{ + workflowID: wfID, + status: WorkflowStatusError, + err: newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)), + }) + if err != nil { + return nil, fmt.Errorf("failed to record unexpected input type error: %w", err) + } return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } return fn(ctx, typedInput) }) typeErasedWrapper := WrappedWorkflowFunc(func(ctx DBOSContext, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) { - typedInput, ok := input.(P) - if !ok { - // FIXME: we need to record the error in the database here - return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) - } - opts = append(opts, withWorkflowName(fqn)) // Append the name so ctx.RunAsWorkflow can look it up from the registry to apply registration-time options - handle, err := ctx.RunAsWorkflow(ctx, typedErasedWorkflow, typedInput, opts...) + handle, err := ctx.RunAsWorkflow(ctx, typedErasedWorkflow, input, opts...) 
if err != nil { return nil, err } - return &workflowPollingHandle[any]{workflowID: handle.GetWorkflowID(), dbosContext: ctx}, nil // this is only used by recovery and queue runner so far -- queue runner dismisses it + return newWorkflowPollingHandle[any](ctx, handle.GetWorkflowID()), nil // this is only used by recovery and queue runner so far -- queue runner dismisses it }) registerWorkflow(ctx, fqn, typeErasedWrapper, registrationParams.maxRetries, registrationParams.name) @@ -528,10 +541,7 @@ func RunAsWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], // If we got a polling handle, return its typed version if pollingHandle, ok := handle.(*workflowPollingHandle[any]); ok { // We need to convert the polling handle to a typed handle - typedPollingHandle := &workflowPollingHandle[R]{ - workflowID: pollingHandle.workflowID, - dbosContext: pollingHandle.dbosContext, - } + typedPollingHandle := newWorkflowPollingHandle[R](pollingHandle.dbosContext, pollingHandle.workflowID) return typedPollingHandle, nil } @@ -559,11 +569,7 @@ func RunAsWorkflow[P any, R any](ctx DBOSContext, fn GenericWorkflowFunc[P, R], } }() - typedHandle := &workflowHandle[R]{ - workflowID: handle.workflowID, - outcomeChan: typedOutcomeChan, - dbosContext: handle.dbosContext, - } + typedHandle := newWorkflowHandle[R](handle.dbosContext, handle.workflowID, typedOutcomeChan) return typedHandle, nil } @@ -582,10 +588,14 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o } // Lookup the registry for registration-time options - registeredWorkflow, exists := c.workflowRegistry[params.workflowName] + registeredWorkflowAny, exists := c.workflowRegistry.Load(params.workflowName) if !exists { return nil, newNonExistentWorkflowError(params.workflowName) } + registeredWorkflow, ok := registeredWorkflowAny.(workflowRegistryEntry) + if !ok { + return nil, fmt.Errorf("invalid workflow registry entry type for workflow %s", params.workflowName) + } if registeredWorkflow.maxRetries > 0 { params.maxRetries = registeredWorkflow.maxRetries } @@ -597,6 +607,11 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o parentWorkflowState, ok := c.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil + // Prevent spawning child workflows from within a step + if isChildWorkflow && parentWorkflowState.isWithinStep { + return nil, newStepExecutionError(parentWorkflowState.workflowID, params.workflowName, "cannot spawn child workflow from within a step") + } + if isChildWorkflow { // Advance step ID if we are a child workflow parentWorkflowState.NextStepID() @@ -626,7 +641,7 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { - return &workflowPollingHandle[any]{workflowID: *childWorkflowID, dbosContext: uncancellableCtx}, nil + return newWorkflowPollingHandle[any](uncancellableCtx, *childWorkflowID), nil } } @@ -637,23 +652,27 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o status = WorkflowStatusPending } - // Check if the user-provided context has a deadline + // Compute the timeout based on the context deadline, if any deadline, ok := c.Deadline() if !ok { deadline = time.Time{} // No deadline set } - - // Compute the timeout based on the deadline var timeout time.Duration if !deadline.IsZero() { timeout = 
time.Until(deadline) - /* unclear to me if this is a real use case: + // The timeout could be in the past, for small deadlines, to propagation delays. If so set it to a minimal value if timeout < 0 { - return nil, newWorkflowExecutionError(workflowID, "deadline is in the past") + timeout = 1 * time.Millisecond } - */ + } + // When enqueuing, we do not set a deadline. It'll be computed with the timeout during dequeue. + if status == WorkflowStatusEnqueued { + deadline = time.Time{} } + if params.priority > uint(math.MaxInt) { + return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.priority, math.MaxInt) + } workflowStatus := WorkflowStatus{ Name: params.workflowName, ApplicationVersion: params.applicationVersion, @@ -695,7 +714,7 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o if err := tx.Commit(uncancellableCtx); err != nil { return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) } - return &workflowPollingHandle[any]{workflowID: workflowStatus.ID, dbosContext: uncancellableCtx}, nil + return newWorkflowPollingHandle[any](uncancellableCtx, workflowStatus.ID), nil } // Record child workflow relationship if this is a child workflow @@ -728,11 +747,19 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o workflowCtx := WithValue(c, workflowStateKey, wfState) - // If the workflow has a durable deadline, set it in the context. + // If the workflow has a timeout but no deadline, compute the deadline from the timeout. + // Else use the durable deadline. + durableDeadline := time.Time{} + if insertStatusResult.timeout > 0 && insertStatusResult.workflowDeadline.IsZero() { + durableDeadline = time.Now().Add(insertStatusResult.timeout) + } else if !insertStatusResult.workflowDeadline.IsZero() { + durableDeadline = insertStatusResult.workflowDeadline + } + var stopFunc func() bool cancelFuncCompleted := make(chan struct{}) - if !insertStatusResult.workflowDeadline.IsZero() { - workflowCtx, _ = WithTimeout(workflowCtx, time.Until(insertStatusResult.workflowDeadline)) + if !durableDeadline.IsZero() { + workflowCtx, _ = WithTimeout(workflowCtx, time.Until(durableDeadline)) // Register a cancel function that cancels the workflow in the DB as soon as the context is cancelled dbosCancelFunction := func() { c.logger.Info("Cancelling workflow", "workflow_id", workflowID) @@ -787,7 +814,7 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o close(outcomeChan) }() - return &workflowHandle[any]{workflowID: workflowID, outcomeChan: outcomeChan, dbosContext: uncancellableCtx}, nil + return newWorkflowHandle[any](uncancellableCtx, workflowID, outcomeChan), nil } /******************************/ @@ -819,9 +846,9 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams { if params == nil { return &StepParams{ MaxRetries: 0, // Default to no retries - BackoffFactor: 2.0, - BaseInterval: 100 * time.Millisecond, // Default base interval - MaxInterval: 5 * time.Second, // Default max interval + BackoffFactor: _DEFAULT_STEP_BACKOFF_FACTOR, + BaseInterval: _DEFAULT_STEP_BASE_INTERVAL, // Default base interval + MaxInterval: _DEFAULT_STEP_MAX_INTERVAL, // Default max interval StepName: func() string { if value, ok := typeErasedStepNameToStepName.Load(stepName); ok { return value.(string) @@ -833,13 +860,13 @@ func setStepParamDefaults(params *StepParams, stepName string) *StepParams { // Set defaults for zero values if 
params.BackoffFactor == 0 { - params.BackoffFactor = 2.0 // Default backoff factor + params.BackoffFactor = _DEFAULT_STEP_BACKOFF_FACTOR // Default backoff factor } if params.BaseInterval == 0 { - params.BaseInterval = 100 * time.Millisecond // Default base interval + params.BaseInterval = _DEFAULT_STEP_BASE_INTERVAL // Default base interval } if params.MaxInterval == 0 { - params.MaxInterval = 5 * time.Second // Default max interval + params.MaxInterval = _DEFAULT_STEP_MAX_INTERVAL // Default max interval } if len(params.StepName) == 0 { // If the step name is not provided, use the function name @@ -918,25 +945,24 @@ func RunAsStep[R any](ctx DBOSContext, fn GenericStepFunc[R]) (R, error) { } func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc) (any, error) { + // Look up for step parameters in the context and set defaults + params, ok := c.Value(StepParamsKey).(*StepParams) + if !ok { + params = nil + } + params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) + // Get workflow state from context wfState, ok := c.Value(workflowStateKey).(*workflowState) if !ok || wfState == nil { - // TODO: try to print step name - return nil, newStepExecutionError("", "", "workflow state not found in context: are you running this step within a workflow?") + return nil, newStepExecutionError("", params.StepName, "workflow state not found in context: are you running this step within a workflow?") } // This should not happen when called from the package-level RunAsStep if fn == nil { - return nil, newStepExecutionError(wfState.workflowID, "", "step function cannot be nil") + return nil, newStepExecutionError(wfState.workflowID, params.StepName, "step function cannot be nil") } - // Look up for step parameters in the context and set defaults - params, ok := c.Value(StepParamsKey).(*StepParams) - if !ok { - params = nil - } - params = setStepParamDefaults(params, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) - // If within a step, just run the function directly if wfState.isWithinStep { return fn(c) @@ -1292,7 +1318,7 @@ func (c *dbosContext) RetrieveWorkflow(_ DBOSContext, workflowID string) (Workfl if len(workflowStatus) == 0 { return nil, newNonExistentWorkflowError(workflowID) } - return &workflowPollingHandle[any]{workflowID: workflowID, dbosContext: c}, nil + return newWorkflowPollingHandle[any](c, workflowID), nil } // RetrieveWorkflow returns a typed handle to an existing workflow. 
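The next hunk changes the generic `RetrieveWorkflow[R]` to return a pointer handle (nil on error) instead of a value. A hedged usage sketch of the post-change call pattern — the launched context, the string-returning workflow, and the import path (taken from the module path in the README badges) are assumptions, not part of this diff:

```go
package example

import (
	"log"

	"github.com/dbos-inc/dbos-transact-go/dbos"
)

// printResult assumes workflowID refers to a previously started workflow
// that returns a string. After this change the handle is a pointer, so the
// caller only needs to check err before using it.
func printResult(ctx dbos.DBOSContext, workflowID string) error {
	handle, err := dbos.RetrieveWorkflow[string](ctx, workflowID)
	if err != nil {
		return err // e.g. the workflow ID does not exist
	}
	status, err := handle.GetStatus()
	if err != nil {
		return err
	}
	if status.Status == dbos.WorkflowStatusSuccess {
		result, err := handle.GetResult()
		if err != nil {
			return err
		}
		log.Printf("workflow %s result: %s", handle.GetWorkflowID(), result)
	}
	return nil
}
```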
@@ -1312,9 +1338,9 @@ func (c *dbosContext) RetrieveWorkflow(_ DBOSContext, workflowID string) (Workfl // } else { // log.Printf("Result: %d", result) // } -func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (workflowPollingHandle[R], error) { +func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (*workflowPollingHandle[R], error) { if ctx == nil { - return workflowPollingHandle[R]{}, errors.New("dbosCtx cannot be nil") + return nil, errors.New("dbosCtx cannot be nil") } // Register the output for gob encoding @@ -1325,12 +1351,12 @@ func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (workflowPollin workflowIDs: []string{workflowID}, }) if err != nil { - return workflowPollingHandle[R]{}, fmt.Errorf("failed to retrieve workflow status: %w", err) + return nil, fmt.Errorf("failed to retrieve workflow status: %w", err) } if len(workflowStatus) == 0 { - return workflowPollingHandle[R]{}, newNonExistentWorkflowError(workflowID) + return nil, newNonExistentWorkflowError(workflowID) } - return workflowPollingHandle[R]{workflowID: workflowID, dbosContext: ctx}, nil + return newWorkflowPollingHandle[R](ctx, workflowID), nil } type EnqueueOptions struct { @@ -1355,6 +1381,9 @@ func (c *dbosContext) Enqueue(_ DBOSContext, params EnqueueOptions) (WorkflowHan deadline = time.Now().Add(params.WorkflowTimeout) } + if params.Priority > uint(math.MaxInt) { + return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.Priority, math.MaxInt) + } status := WorkflowStatus{ Name: params.WorkflowName, ApplicationVersion: params.ApplicationVersion, @@ -1392,10 +1421,7 @@ func (c *dbosContext) Enqueue(_ DBOSContext, params EnqueueOptions) (WorkflowHan return nil, fmt.Errorf("failed to commit transaction: %w", err) } - return &workflowPollingHandle[any]{ - workflowID: workflowID, - dbosContext: uncancellableCtx, - }, nil + return newWorkflowPollingHandle[any](uncancellableCtx, workflowID), nil } type GenericEnqueueOptions[P any] struct { @@ -1490,10 +1516,7 @@ func Enqueue[P any, R any](ctx DBOSContext, params GenericEnqueueOptions[P]) (Wo return nil, err } - return &workflowPollingHandle[R]{ - workflowID: handle.GetWorkflowID(), - dbosContext: ctx, - }, nil + return newWorkflowPollingHandle[R](ctx, handle.GetWorkflowID()), nil } // CancelWorkflow cancels a running or enqueued workflow by setting its status to CANCELLED. @@ -1534,7 +1557,7 @@ func (c *dbosContext) ResumeWorkflow(_ DBOSContext, workflowID string) (Workflow if err != nil { return nil, err } - return &workflowPollingHandle[any]{workflowID: workflowID, dbosContext: c}, nil + return newWorkflowPollingHandle[any](c, workflowID), nil } // ResumeWorkflow resumes a cancelled workflow by setting its status back to ENQUEUED. @@ -1569,7 +1592,7 @@ func ResumeWorkflow[R any](ctx DBOSContext, workflowID string) (WorkflowHandle[R if err != nil { return nil, err } - return &workflowPollingHandle[R]{workflowID: workflowID, dbosContext: ctx}, nil + return newWorkflowPollingHandle[R](ctx, workflowID), nil } // ForkWorkflowInput holds configuration parameters for forking workflows. 
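
Two patterns recur in the surrounding hunks: struct literals are replaced by the newWorkflowPollingHandle constructor (with the generic functions now returning a nil pointer on error instead of a zero-value handle), and unsigned inputs such as Priority are bounds-checked against math.MaxInt before any conversion to int, so a large uint cannot wrap to a negative value. The following is a small, self-contained sketch of that overflow guard, using a hypothetical toInt helper rather than the library's inline checks.

package main

import (
	"fmt"
	"math"
)

// toInt converts a uint to an int only when it provably fits, mirroring the
// guards added before the priority and start-step conversions in this diff.
// The helper name is hypothetical; the library performs the check inline.
func toInt(v uint) (int, error) {
	if v > uint(math.MaxInt) {
		return 0, fmt.Errorf("value %d exceeds maximum allowed value %d", v, math.MaxInt)
	}
	return int(v), nil
}

func main() {
	if p, err := toInt(42); err == nil {
		fmt.Println("priority:", p) // small values convert cleanly
	}
	if _, err := toInt(^uint(0)); err != nil {
		fmt.Println(err) // rejected up front instead of wrapping to a negative int
	}
}

The same check-then-convert shape applies to the StartStep guard in the ForkWorkflow hunk that follows.
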
@@ -1592,6 +1615,9 @@ func (c *dbosContext) ForkWorkflow(_ DBOSContext, input ForkWorkflowInput) (Work } // Create input for system database + if input.StartStep > uint(math.MaxInt) { + return nil, fmt.Errorf("start step too large: %d", input.StartStep) + } dbInput := forkWorkflowDBInput{ originalWorkflowID: input.OriginalWorkflowID, forkedWorkflowID: forkedWorkflowID, @@ -1605,10 +1631,7 @@ func (c *dbosContext) ForkWorkflow(_ DBOSContext, input ForkWorkflowInput) (Work return nil, err } - return &workflowPollingHandle[any]{ - workflowID: forkedWorkflowID, - dbosContext: c, - }, nil + return newWorkflowPollingHandle[any](c, forkedWorkflowID), nil } // ForkWorkflow creates a new workflow instance by copying an existing workflow from a specific step. @@ -1662,10 +1685,7 @@ func ForkWorkflow[R any](ctx DBOSContext, input ForkWorkflowInput) (WorkflowHand if err != nil { return nil, err } - return &workflowPollingHandle[R]{ - workflowID: handle.GetWorkflowID(), - dbosContext: ctx, - }, nil + return newWorkflowPollingHandle[R](ctx, handle.GetWorkflowID()), nil } // listWorkflowsParams holds configuration parameters for listing workflows diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 73fbf868..8457ab1e 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -6,7 +6,7 @@ Test workflow and steps features [x] workflow idempotency [x] workflow DLQ [x] workflow conflicting name -[] workflow timeouts & deadlines (including child workflows) +[x] workflow timeouts & deadlines (including child workflows) */ import ( @@ -831,6 +831,48 @@ func TestChildWorkflow(t *testing.T) { t.Fatalf("expected child result '%s', got '%s'", result, childResult) } }) + + t.Run("ChildWorkflowCannotBeSpawnedFromStep", func(t *testing.T) { + // Child workflow for testing + childWf := func(dbosCtx DBOSContext, input string) (string, error) { + return "child-result", nil + } + RegisterWorkflow(dbosCtx, childWf) + + // Step that tries to spawn a child workflow - this should fail + stepThatSpawnsChild := func(ctx context.Context, input string) (string, error) { + dbosCtx := ctx.(DBOSContext) + _, err := RunAsWorkflow(dbosCtx, childWf, input) + if err != nil { + return "", err + } + return "should-not-reach", nil + } + + // Workflow that calls the step + parentWf := func(ctx DBOSContext, input string) (string, error) { + return RunAsStep(ctx, func(context context.Context) (string, error) { + return stepThatSpawnsChild(context, input) + }) + } + RegisterWorkflow(dbosCtx, parentWf) + + // Execute the workflow - should fail when step tries to spawn child workflow + handle, err := RunAsWorkflow(dbosCtx, parentWf, "test-input") + require.NoError(t, err, "failed to start parent workflow") + + // Expect the workflow to fail + _, err = handle.GetResult() + require.Error(t, err, "expected error when spawning child workflow from step, but got none") + + // Check the error type and message + dbosErr, ok := err.(*DBOSError) + require.True(t, ok, "expected error to be of type *DBOSError, got %T", err) + require.Equal(t, StepExecutionError, dbosErr.Code, "expected error code to be StepExecutionError, got %v", dbosErr.Code) + + expectedMessagePart := "cannot spawn child workflow from within a step" + require.Contains(t, err.Error(), expectedMessagePart, "expected error message to contain %q, but got %q", expectedMessagePart, err.Error()) + }) } // Idempotency workflows moved to test functions @@ -1072,60 +1114,51 @@ func TestWorkflowDeadLetterQueue(t *testing.T) { require.Error(t, err, "expected dead letter queue 
error when restarting workflow with same ID but got none") dbosErr, ok = err.(*DBOSError) - if !ok { - t.Fatalf("expected DBOSError, got %T", err) - } - if dbosErr.Code != DeadLetterQueueError { - t.Fatalf("expected DeadLetterQueueError, got %v", dbosErr.Code) - } + require.True(t, ok, "expected error to be of type *DBOSError, got %T", err) + require.Equal(t, dbosErr.Code, DeadLetterQueueError, "expected error code to be DeadLetterQueueError") - // Unlock the workflow to allow it to complete - deadLetterQueueEvent.Set() - /* - // TODO: test resume when implemented - resumedHandle, err := ... + // Now resume the workflow -- this clears the DLQ status + resumedHandle, err := ResumeWorkflow[int](dbosCtx, wfID) + require.NoError(t, err, "failed to resume workflow") - // Recover pending workflows again - should work without error - _, err = recoverPendingWorkflows(executor.(*dbosContext), []string{"local"}) - if err != nil { - t.Fatalf("failed to recover pending workflows after resume: %v", err) - } + // Recover pending workflows again - should work without error + _, err = recoverPendingWorkflows(dbosCtx.(*dbosContext), []string{"local"}) + require.NoError(t, err, "failed to recover pending workflows after resume") - // Complete the blocked workflow - deadLetterQueueEvent.Set() + // Complete the blocked workflow + deadLetterQueueEvent.Set() - // Wait for both handles to complete - result1, err = handle.GetResult(context.Background()) - if err != nil { - t.Fatalf("failed to get result from original handle: %v", err) - } + // Wait for both handles to complete + result1, err := handle.GetResult() + if err != nil { + t.Fatalf("failed to get result from original handle: %v", err) + } - result2, err := resumedHandle.GetResult(context.Background()) - if err != nil { - t.Fatalf("failed to get result from resumed handle: %v", err) - } + result2, err := resumedHandle.GetResult() + if err != nil { + t.Fatalf("failed to get result from resumed handle: %v", err) + } - if result1 != result2 { - t.Fatalf("expected both handles to return same result, got %v and %v", result1, result2) - } + if result1 != result2 { + t.Fatalf("expected both handles to return same result, got %v and %v", result1, result2) + } - // Verify workflow status is SUCCESS - status, err = handle.GetStatus() - if err != nil { - t.Fatalf("failed to get final workflow status: %v", err) - } - if status.Status != WorkflowStatusSuccess { - t.Fatalf("expected workflow status to be SUCCESS, got %v", status.Status) - } + // Verify workflow status is SUCCESS + status, err = handle.GetStatus() + if err != nil { + t.Fatalf("failed to get final workflow status: %v", err) + } + if status.Status != WorkflowStatusSuccess { + t.Fatalf("expected workflow status to be SUCCESS, got %v", status.Status) + } - // Verify that retries of a completed workflow do not raise the DLQ exception - for i := 0; i < maxRecoveryAttempts*2; i++ { - _, err = RunAsWorkflow(executor, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) - if err != nil { - t.Fatalf("unexpected error when retrying completed workflow: %v", err) - } + // Verify that retries of a completed workflow do not raise the DLQ exception + for i := 0; i < maxRecoveryAttempts*2; i++ { + _, err = RunAsWorkflow(dbosCtx, deadLetterQueueWorkflow, "test", WithWorkflowID(wfID)) + if err != nil { + t.Fatalf("unexpected error when retrying completed workflow: %v", err) } - */ + } }) t.Run("InfiniteRetriesWorkflow", func(t *testing.T) { @@ -1230,7 +1263,7 @@ func TestScheduledWorkflows(t *testing.T) { // Verify 
timing - each execution should be approximately 1 second apart scheduleInterval := 1 * time.Second - allowedSlack := 2 * time.Second + allowedSlack := 3 * time.Second for i, execTime := range executionTimes { // Calculate expected execution time based on schedule interval @@ -1423,7 +1456,7 @@ func TestSendRecv(t *testing.T) { if result != "message1-message2-message3" { t.Fatalf("expected received message to be 'message1-message2-message3', got '%s'", result) } - // XXX This is not a great condition: when all the tests run there's quite some randomness to this + // FIXME This is not a great condition: when all the tests run there's quite some randomness to this if time.Since(start) > 10*time.Second { t.Fatalf("receive workflow took too long to complete, expected < 5s, got %v", time.Since(start)) } @@ -2904,7 +2937,7 @@ func TestWorkflowTimeout(t *testing.T) { } // Check the deadline on the status was is within an expected range (start time + timeout * .1) - // XXX this might be flaky and frankly not super useful + // FIXME this might be flaky and frankly not super useful expectedDeadline := start.Add(timeout * 10 / 100) if status.Deadline.Before(expectedDeadline) || status.Deadline.After(start.Add(timeout)) { t.Fatalf("expected workflow deadline to be within %v and %v, got %v", expectedDeadline, start.Add(timeout), status.Deadline) diff --git a/go.mod b/go.mod index 2b3e463f..b42de03c 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/dbos-inc/dbos-transact-go go 1.23.0 +toolchain go1.25.0 + require ( github.com/golang-migrate/migrate/v4 v4.18.3 github.com/google/uuid v1.6.0