diff --git a/.github/workflows/chaos-tests.yml b/.github/workflows/chaos-tests.yml
new file mode 100644
index 00000000..07114440
--- /dev/null
+++ b/.github/workflows/chaos-tests.yml
@@ -0,0 +1,43 @@
+name: Run Chaos Tests
+on:
+  schedule:
+    # Runs every hour on the hour
+    - cron: '0 * * * *'
+  push:
+    branches:
+      - main
+  pull_request:
+    types:
+      - ready_for_review
+      - opened
+      - reopened
+      - synchronize
+  workflow_dispatch:
+
+jobs:
+  chaos-test:
+    runs-on: ubuntu-latest
+    timeout-minutes: 60
+    steps:
+      - uses: actions/checkout@v5
+        with:
+          fetch-depth: 0
+          fetch-tags: true
+
+      - name: Setup Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: '1.25.x'
+
+      - name: Download dependencies
+        run: go mod download
+
+      - name: Install gotestsum
+        run: go install gotest.tools/gotestsum@latest
+
+      - name: Run chaos tests
+        run: go vet ./... && go test -v -race -timeout 60m -count=1 ./...
+        working-directory: ./chaos_tests
+        env:
+          PGPASSWORD: a!b@c$d()e*_,/:;=?@ff[]22
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 762c76ce..344a242d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -44,16 +44,6 @@ jobs:
       with:
         go-version: '1.25.x'
 
-      - name: Cache Go modules
-        uses: actions/cache@v4
-        with:
-          path: |
-            ~/go/pkg/mod
-            ~/.cache/go-build
-          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
-          restore-keys: |
-            ${{ runner.os }}-go-
-
       - name: Download dependencies
         run: go mod download
 
diff --git a/chaos_tests/chaos_test.go b/chaos_tests/chaos_test.go
new file mode 100644
index 00000000..e23c2b6f
--- /dev/null
+++ b/chaos_tests/chaos_test.go
@@ -0,0 +1,475 @@
+package chaos_test
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"math/rand"
+	"net/url"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"runtime"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/dbos-inc/dbos-transact-golang/dbos"
+	"github.com/google/uuid"
+	"github.com/jackc/pgx/v5"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+var testCLIPath string
+
+// TestMain builds the CLI and starts Postgres once for all tests
+func TestMain(m *testing.M) {
+	// Get the directory where this test file is located
+	_, filename, _, ok := runtime.Caller(0)
+	if !ok {
+		fmt.Fprintf(os.Stderr, "Failed to get current file path\n")
+		os.Exit(1)
+	}
+
+	// Navigate to the project root then to cmd/dbos
+	testDir := filepath.Dir(filename)
+	projectRoot := filepath.Dir(testDir) // Go up from chaos_tests/ to project root
+	cmdDir := filepath.Join(projectRoot, "cmd", "dbos")
+
+	// Build output path in the chaos_tests directory (where the test is)
+	cliPath := filepath.Join(testDir, "dbos-cli-test")
+
+	// Delete any existing binary before building
+	os.Remove(cliPath)
+
+	// Build the CLI from the cmd/dbos directory
+	buildCmd := exec.Command("go", "build", "-o", cliPath, ".")
+	buildCmd.Dir = cmdDir
+	buildOutput, buildErr := buildCmd.CombinedOutput()
+	if buildErr != nil {
+		fmt.Fprintf(os.Stderr, "Failed to build CLI: %s\n", string(buildOutput))
+		os.Exit(1)
+	}
+
+	// Set the global CLI path
+	absPath, err := filepath.Abs(cliPath)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to get absolute path: %v\n", err)
+		os.Exit(1)
+	}
+	testCLIPath = absPath
+
+	// Start postgres
+	startPostgresCmd := exec.Command(cliPath, "postgres", "start")
+	startOutput, startErr := startPostgresCmd.CombinedOutput()
+	if startErr != nil {
+		fmt.Fprintf(os.Stderr, "Failed to start postgres: %s\n", string(startOutput))
+		os.Exit(1)
+	}
+ + // Run tests + code := m.Run() + + // Clean up CLI binary + os.Remove(cliPath) + + os.Exit(code) +} + +// Use the DBOS CLI to start postgres +func startPostgres(t *testing.T, cliPath string) { + cmd := exec.Command(cliPath, "postgres", "start") + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Failed to start postgres: %s", string(output)) +} + +// Use the DBOS CLI to stop postgres +func stopPostgres(t *testing.T, cliPath string) { + cmd := exec.Command(cliPath, "postgres", "stop") + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Failed to stop postgres: %s", string(output)) +} + +// PostgresChaosMonkey starts a goroutine that randomly stops and starts PostgreSQL +func PostgresChaosMonkey(t *testing.T, ctx context.Context, wg *sync.WaitGroup) { + cliPath := testCLIPath + + wg.Add(1) + go func() { + defer wg.Done() + defer t.Logf("Chaos Monkey: Exiting") + + for { + // Check for context cancellation first + select { + case <-ctx.Done(): + startPostgres(t, cliPath) + return + default: + } + + // Random down time between 0 and 2 seconds + downTime := time.Duration(rand.Float64()*2) * time.Second + + // Stop PostgreSQL + require.Eventually(t, func() bool { + stopPostgres(t, cliPath) + return true + }, 5*time.Second, 100*time.Millisecond) + t.Logf("🐒 Chaos Monkey: Stopped PostgreSQL") + + // Sleep for random down time + select { + case <-time.After(downTime): + // Start PostgreSQL again + require.Eventually(t, func() bool { + startPostgres(t, cliPath) + return true + }, 5*time.Second, 100*time.Millisecond) + t.Logf("🐒 Chaos Monkey: Starting PostgreSQL") + case <-ctx.Done(): + // Ensure PostgreSQL is started before exiting + require.Eventually(t, func() bool { + startPostgres(t, cliPath) + return true + }, 5*time.Second, 100*time.Millisecond) + return + } + + // Wait a bit before next chaos event (between 5 and 40 seconds) + upTime := time.Duration(5+rand.Float64()*35) * time.Second + select { + case <-time.After(upTime): + // Continue to next iteration + case <-ctx.Done(): + t.Logf("Chaos Monkey: Context cancelled during uptime") + return + } + } + }() +} + +// setupDBOS sets up a DBOS context for integration testing +func setupDBOS(t *testing.T) dbos.DBOSContext { + t.Helper() + + databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL") + if databaseURL == "" { + password := os.Getenv("PGPASSWORD") + if password == "" { + password = "dbos" + } + databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", url.QueryEscape(password)) + } + + // Clean up the test database + parsedURL, err := pgx.ParseConfig(databaseURL) + require.NoError(t, err) + + dbName := parsedURL.Database + postgresURL := parsedURL.Copy() + postgresURL.Database = "postgres" + conn, err := pgx.ConnectConfig(context.Background(), postgresURL) + require.NoError(t, err) + defer conn.Close(context.Background()) + + _, err = conn.Exec(context.Background(), "DROP DATABASE IF EXISTS "+dbName+" WITH (FORCE)") + require.NoError(t, err) + + dbosCtx, err := dbos.NewDBOSContext(context.Background(), dbos.Config{ + DatabaseURL: databaseURL, + AppName: "chaos-test", + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})), + }) + require.NoError(t, err) + require.NotNil(t, dbosCtx) + + // Register cleanup to run after test completes + t.Cleanup(func() { + if dbosCtx != nil { + dbosCtx.Shutdown(30 * time.Second) + } + }) + + return dbosCtx +} + +// Test workflow with multiple steps and transactions +func TestChaosWorkflow(t *testing.T) { + 
dbosCtx := setupDBOS(t) + + // Start chaos monkey + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PostgresChaosMonkey(t, ctx, &wg) + + // Define scheduled workflow that runs every second + scheduledWorkflow := func(ctx dbos.DBOSContext, scheduledTime time.Time) (struct{}, error) { + return struct{}{}, nil + } + + // Define step functions + stepOne := func(_ context.Context, x int) (int, error) { + return x + 1, nil + } + + stepTwo := func(_ context.Context, x int) (int, error) { + return x + 2, nil + } + + // Define workflow function + workflow := func(ctx dbos.DBOSContext, x int) (int, error) { + // Execute step one + x, err := dbos.RunAsStep(ctx, func(context context.Context) (int, error) { + return stepOne(context, x) + }) + if err != nil { + return 0, fmt.Errorf("step one failed: %w", err) + } + + // Execute step two + x, err = dbos.RunAsStep(ctx, func(context context.Context) (int, error) { + return stepTwo(context, x) + }) + if err != nil { + return 0, fmt.Errorf("step two failed: %w", err) + } + + return x, nil + } + + // Register the workflows + dbos.RegisterWorkflow(dbosCtx, workflow) + // Register scheduled workflow to run every second for chaos testing + dbos.RegisterWorkflow(dbosCtx, scheduledWorkflow, dbos.WithSchedule("* * * * * *"), dbos.WithWorkflowName("ScheduledChaosTest")) + + err := dbosCtx.Launch() + require.NoError(t, err) + + // Run multiple workflows + numWorkflows := 10000 + for i := range numWorkflows { + if i%100 == 0 { + t.Logf("Starting workflow %d/%d", i+1, numWorkflows) + } + handle, err := dbos.RunWorkflow(dbosCtx, workflow, i) + require.NoError(t, err, "failed to start workflow %d", i) + + result, err := handle.GetResult() + require.NoError(t, err, "failed to get result for workflow %d", i) + assert.Equal(t, i+3, result, "unexpected result for workflow %d", i) + } + + // Validate scheduled workflow executions using ListWorkflows + scheduledWorkflows, err := dbos.ListWorkflows(dbosCtx, + dbos.WithName("ScheduledChaosTest"), + dbos.WithStatus([]dbos.WorkflowStatusType{dbos.WorkflowStatusSuccess}), + dbos.WithSortDesc(), + dbos.WithLimit(1), + dbos.WithLoadInput(false), + dbos.WithLoadOutput(false), + ) + require.NoError(t, err, "failed to list scheduled workflows") + + assert.Equal(t, len(scheduledWorkflows), 1, "Expected exactly one scheduled workflow execution") + + // Check the last execution was within 10 seconds -- reasonable for a 1 second schedule and 2 seconds postgres downtime + latestWorkflow := scheduledWorkflows[0] // Sorted descending + timeSinceLastExecution := time.Since(latestWorkflow.CreatedAt) + assert.Less(t, timeSinceLastExecution, 10*time.Second, + "Last scheduled execution was %v ago, expected less than 60 seconds", timeSinceLastExecution) +} + +// Test send/recv functionality +func TestChaosRecv(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Start chaos monkey + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PostgresChaosMonkey(t, ctx, &wg) + + topic := "test_topic" + + // Define recv workflow + recvWorkflow := func(ctx dbos.DBOSContext, _ string) (string, error) { + // Receive from topic with timeout + value, err := dbos.Recv[string](ctx, topic, 10*time.Second) + if err != nil { + return "", fmt.Errorf("failed to receive: %w", err) + } + return value, nil + } + + // Register the workflow + dbos.RegisterWorkflow(dbosCtx, recvWorkflow) + + err := dbosCtx.Launch() + require.NoError(t, err) + + // 
Run multiple workflows with send/recv + numWorkflows := 10000 + for i := range numWorkflows { + if i%100 == 0 { + t.Logf("Starting workflow %d/%d", i+1, numWorkflows) + } + handle, err := dbos.RunWorkflow(dbosCtx, recvWorkflow, "") + require.NoError(t, err, "failed to start workflow %d", i) + + // Generate a random value + value := uuid.NewString() + + // Give some time to the workflow to enter the sleep state + time.Sleep((5 * time.Millisecond)) + + // Send the value to the workflow + err = dbos.Send(dbosCtx, handle.GetWorkflowID(), value, topic) + require.NoError(t, err, "failed to send value for workflow %d", i) + + // Get the result and verify it matches + result, err := handle.GetResult() + require.NoError(t, err, "failed to get result for workflow %d", i) + assert.Equal(t, value, result, "unexpected result for workflow %d", i) + } +} + +// Test event functionality +func TestChaosEvents(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Start chaos monkey + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PostgresChaosMonkey(t, ctx, &wg) + + key := "test_key" + + // Define event workflow + eventWorkflow := func(ctx dbos.DBOSContext, _ string) (string, error) { + value := uuid.NewString() + err := dbos.SetEvent(ctx, key, value) + if err != nil { + return "", fmt.Errorf("failed to set event: %w", err) + } + return value, nil + } + + // Register the workflow + dbos.RegisterWorkflow(dbosCtx, eventWorkflow) + + err := dbosCtx.Launch() + require.NoError(t, err) + + // Run multiple workflows with events + numWorkflows := 5000 + for i := range numWorkflows { + if i%100 == 0 { + t.Logf("Starting workflow %d/%d", i+1, numWorkflows) + } + wfID := uuid.NewString() + + // Start workflow with specific ID + handle, err := dbos.RunWorkflow(dbosCtx, eventWorkflow, "", dbos.WithWorkflowID(wfID)) + require.NoError(t, err, "failed to start workflow %d", i) + + // Get the workflow result + value, err := handle.GetResult() + require.NoError(t, err, "failed to get result for workflow %d", i) + + // Retrieve the event and verify it matches + retrievedValue, err := dbos.GetEvent[string](dbosCtx, wfID, key, 0) + require.NoError(t, err, "failed to get event for workflow %d", i) + assert.Equal(t, value, retrievedValue, "unexpected event value for workflow %d", i) + } +} + +// Test queue functionality +func TestChaosQueues(t *testing.T) { + dbosCtx := setupDBOS(t) + + // Start chaos monkey + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PostgresChaosMonkey(t, ctx, &wg) + + queue := dbos.NewWorkflowQueue(dbosCtx, "test_queue") + + // Define step functions + stepOne := func(ctx dbos.DBOSContext, x int) (int, error) { + // Run as a step + result, err := dbos.RunAsStep(ctx, func(context context.Context) (int, error) { + return x + 1, nil + }) + if err != nil { + return 0, fmt.Errorf("step one failed: %w", err) + } + return result, nil + } + + stepTwo := func(ctx dbos.DBOSContext, x int) (int, error) { + // Run as a step + result, err := dbos.RunAsStep(ctx, func(context context.Context) (int, error) { + return x + 2, nil + }) + if err != nil { + return 0, fmt.Errorf("step two failed: %w", err) + } + return result, nil + } + + // Define main workflow that enqueues other workflows + workflow := func(ctx dbos.DBOSContext, x int) (int, error) { + // Enqueue step one + handle1, err := dbos.RunWorkflow(ctx, stepOne, x, dbos.WithQueue(queue.Name)) + if err != nil { + return 0, 
fmt.Errorf("failed to enqueue step one: %w", err) + } + x, err = handle1.GetResult() + if err != nil { + return 0, fmt.Errorf("failed to get result from step one: %w", err) + } + + // Enqueue step two + handle2, err := dbos.RunWorkflow(ctx, stepTwo, x, dbos.WithQueue(queue.Name)) + if err != nil { + return 0, fmt.Errorf("failed to enqueue step two: %w", err) + } + x, err = handle2.GetResult() + if err != nil { + return 0, fmt.Errorf("failed to get result from step two: %w", err) + } + return x, nil + } + + // Register all workflows + dbos.RegisterWorkflow(dbosCtx, stepOne) + dbos.RegisterWorkflow(dbosCtx, stepTwo) + dbos.RegisterWorkflow(dbosCtx, workflow) + + err := dbosCtx.Launch() + require.NoError(t, err) + + // Run multiple workflows + numWorkflows := 30 + for i := range numWorkflows { + if i%10 == 0 { + t.Logf("Starting workflow %d/%d", i+1, numWorkflows) + } + // Enqueue the main workflow + handle, err := dbos.RunWorkflow(dbosCtx, workflow, i, dbos.WithQueue(queue.Name)) + require.NoError(t, err, "failed to enqueue workflow %d", i) + + result, err := handle.GetResult() + require.NoError(t, err, "failed to get result for workflow %d", i) + assert.Equal(t, i+3, result, "unexpected result for workflow %d", i) + } +} diff --git a/cmd/dbos/README.md b/cmd/dbos/README.md index c294324c..3cdfc49d 100644 --- a/cmd/dbos/README.md +++ b/cmd/dbos/README.md @@ -6,7 +6,7 @@ The DBOS CLI is a command-line interface for managing DBOS workflows. ### From Source ```bash -go install github.com/dbos-inc/dbos-transact-golang/cmd/dbos +go install github.com/dbos-inc/dbos-transact-golang/cmd/dbos@latest ``` ### Build Locally diff --git a/cmd/dbos/postgres.go b/cmd/dbos/postgres.go index a36734d7..f07f3966 100644 --- a/cmd/dbos/postgres.go +++ b/cmd/dbos/postgres.go @@ -5,11 +5,13 @@ import ( "database/sql" "fmt" "io" + "net/url" "os" "time" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/volume" "github.com/docker/docker/client" "github.com/docker/go-connections/nat" _ "github.com/jackc/pgx/v5/stdlib" @@ -17,9 +19,10 @@ import ( ) const ( - containerName = "dbos-db" - imageName = "pgvector/pgvector:pg16" - pgData = "/var/lib/postgresql/data" + containerName = "dbos-db" + imageName = "pgvector/pgvector:pg16" + pgData = "/var/lib/postgresql/data" + hostPgDataVolumeName = "pgdata" ) var postgresCmd = &cobra.Command{ @@ -146,6 +149,10 @@ func startDockerPostgres() error { }, } + _, err = cli.VolumeCreate(ctx, volume.CreateOptions{Name: hostPgDataVolumeName}) + if err != nil { + return fmt.Errorf("failed to create volume %s for Postgres: %w", hostPgDataVolumeName, err) + } hostConfig := &container.HostConfig{ PortBindings: nat.PortMap{ "5432/tcp": []nat.PortBinding{ @@ -156,6 +163,9 @@ func startDockerPostgres() error { }, }, AutoRemove: true, + Binds: []string{ + fmt.Sprintf("%s:%s", hostPgDataVolumeName, pgData), + }, } resp, err := cli.ContainerCreate(ctx, config, hostConfig, nil, nil, containerName) @@ -174,7 +184,7 @@ func startDockerPostgres() error { return err } - logger.Info("Postgres available", "url", fmt.Sprintf("postgres://postgres:%s@localhost:5432", password)) + logger.Info("Postgres available", "url", fmt.Sprintf("postgres://postgres:%s@localhost:5432", url.QueryEscape(password))) return nil } @@ -225,7 +235,7 @@ func waitForPostgres() error { password = "dbos" } - connStr := fmt.Sprintf("postgres://postgres:%s@localhost:5432/postgres?connect_timeout=2&sslmode=disable", password) + connStr := 
fmt.Sprintf("postgres://postgres:%s@localhost:5432/postgres?connect_timeout=2&sslmode=disable", url.QueryEscape(password)) // Try for up to 30 seconds for i := 0; i < 30; i++ { diff --git a/dbos/client.go b/dbos/client.go index 01be7752..f87a762a 100644 --- a/dbos/client.go +++ b/dbos/client.go @@ -165,7 +165,7 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu tx, err := dbosCtx.systemDB.(*sysDB).pool.Begin(uncancellableCtx) if err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) + return nil, newWorkflowExecutionError(workflowID, fmt.Errorf("failed to begin transaction: %v", err)) } defer tx.Rollback(uncancellableCtx) // Rollback if not committed diff --git a/dbos/errors.go b/dbos/errors.go index 1fb9b931..57679afa 100644 --- a/dbos/errors.go +++ b/dbos/errors.go @@ -28,7 +28,6 @@ const ( type DBOSError struct { Message string // Human-readable error message Code DBOSErrorCode // Error type code for programmatic handling - IsBase bool // Internal errors that shouldn't be caught by user code // Optional context fields - only set when relevant to the error WorkflowID string // Associated workflow identifier @@ -40,6 +39,8 @@ type DBOSError struct { ExpectedName string // Expected function name (for determinism errors) RecordedName string // Actually recorded function name (for determinism errors) MaxRetries int // Maximum retry limit (for retry-related errors) + + wrappedErr error // Underlying error being wrapped (for error unwrapping) } // Error returns a formatted error message including the error code. @@ -48,6 +49,12 @@ func (e *DBOSError) Error() string { return fmt.Sprintf("DBOS Error %d: %s", int(e.Code), e.Message) } +// Unwrap returns the underlying error, if any. +// This enables Go's error unwrapping functionality with errors.Is and errors.As. 
+func (e *DBOSError) Unwrap() error { + return e.wrappedErr +} + func newConflictingWorkflowError(workflowID, message string) *DBOSError { msg := fmt.Sprintf("Conflicting workflow invocation with the same ID (%s)", workflowID) if message != "" { @@ -105,7 +112,6 @@ func newWorkflowCancelledError(workflowID string) *DBOSError { return &DBOSError{ Message: fmt.Sprintf("Workflow %s was cancelled", workflowID), Code: WorkflowCancelled, - IsBase: true, } } @@ -114,7 +120,6 @@ func newWorkflowConflictIDError(workflowID string) *DBOSError { Message: fmt.Sprintf("Conflicting workflow ID %s", workflowID), Code: ConflictingIDError, WorkflowID: workflowID, - IsBase: true, } } @@ -123,7 +128,6 @@ func newWorkflowUnexpectedResultType(workflowID, expectedType, actualType string Message: fmt.Sprintf("Workflow %s returned unexpected result type: expected %s, got %s", workflowID, expectedType, actualType), Code: WorkflowUnexpectedTypeError, WorkflowID: workflowID, - IsBase: true, } } @@ -131,16 +135,15 @@ func newWorkflowUnexpectedInputType(workflowName, expectedType, actualType strin return &DBOSError{ Message: fmt.Sprintf("Workflow %s received unexpected input type: expected %s, got %s", workflowName, expectedType, actualType), Code: WorkflowUnexpectedTypeError, - IsBase: true, } } -func newWorkflowExecutionError(workflowID, message string) *DBOSError { +func newWorkflowExecutionError(workflowID string, err error) *DBOSError { return &DBOSError{ - Message: fmt.Sprintf("Workflow %s execution error: %s", workflowID, message), + Message: fmt.Sprintf("Workflow %s execution error: %s", workflowID, err.Error()), Code: WorkflowExecutionError, WorkflowID: workflowID, - IsBase: true, + wrappedErr: err, } } @@ -150,7 +153,6 @@ func newStepExecutionError(workflowID, stepName, message string) *DBOSError { Code: StepExecutionError, WorkflowID: workflowID, StepName: stepName, - IsBase: true, } } @@ -160,7 +162,6 @@ func newDeadLetterQueueError(workflowID string, maxRetries int) *DBOSError { Code: DeadLetterQueueError, WorkflowID: workflowID, MaxRetries: maxRetries, - IsBase: true, } } @@ -171,7 +172,7 @@ func newMaxStepRetriesExceededError(workflowID, stepName string, maxRetries int, WorkflowID: workflowID, StepName: stepName, MaxRetries: maxRetries, - IsBase: true, + wrappedErr: err, } } diff --git a/dbos/queue.go b/dbos/queue.go index 74b91f34..82f89b76 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -183,11 +183,13 @@ func (qr *queueRunner) run(ctx *dbosContext) { // Iterate through all queues in the registry for queueName, queue := range qr.workflowQueueRegistry { // Call DequeueWorkflows for each queue - dequeuedWorkflows, err := ctx.systemDB.dequeueWorkflows(ctx, dequeueWorkflowsInput{ - queue: queue, - executorID: ctx.executorID, - applicationVersion: ctx.applicationVersion, - }) + dequeuedWorkflows, err := retryWithResult(ctx, func() ([]dequeuedWorkflow, error) { + return ctx.systemDB.dequeueWorkflows(ctx, dequeueWorkflowsInput{ + queue: queue, + executorID: ctx.executorID, + applicationVersion: ctx.applicationVersion, + }) + }, withRetrierLogger(qr.logger)) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { switch pgErr.Code { diff --git a/dbos/system_database.go b/dbos/system_database.go index aae95837..88b78ccb 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -5,15 +5,18 @@ import ( _ "embed" "errors" "fmt" + "io" "log/slog" "math" "math/rand" + "net" "net/url" "strings" "sync" "time" "github.com/google/uuid" + "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" 
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" @@ -1212,7 +1215,6 @@ func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperation } if err != nil { - s.logger.Error("RecordOperationResult Error occurred", "error", err) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_UNIQUE_VIOLATION { return newWorkflowConflictIDError(input.workflowID) } @@ -1541,8 +1543,6 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err } } else { // First execution: calculate and record the end time - s.logger.Debug("Durable sleep", "stepID", stepID, "duration", input.duration) - endTime = time.Now().Add(input.duration) // Record the operation result with the calculated end time @@ -1646,7 +1646,7 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) { } // If the underlying connection is closed, attempt to re-acquire a new one if poolConn.Conn().IsClosed() { - s.logger.Error("Notification listener connection closed. re-acquiring") + s.logger.Debug("Notification listener connection closed. re-acquiring") poolConn.Release() for { if ctx.Err() != nil { @@ -1658,7 +1658,7 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) { retryAttempt = 0 break } - s.logger.Error("failed to re-acquire connection for notification listener", "error", err) + s.logger.Debug("failed to re-acquire connection for notification listener", "error", err) time.Sleep(backoffWithJitter(retryAttempt)) retryAttempt++ } @@ -1679,11 +1679,15 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) { switch n.Channel { case _DBOS_NOTIFICATIONS_CHANNEL: if cond, ok := s.workflowNotificationsMap.Load(n.Payload); ok { + cond.(*sync.Cond).L.Lock() cond.(*sync.Cond).Broadcast() + cond.(*sync.Cond).L.Unlock() } case _DBOS_WORKFLOW_EVENTS_CHANNEL: if cond, ok := s.workflowEventsMap.Load(n.Payload); ok { + cond.(*sync.Cond).L.Lock() cond.(*sync.Cond).Broadcast() + cond.(*sync.Cond).L.Unlock() } } } @@ -1828,8 +1832,10 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { // First check if there's already a receiver for this workflow/topic to avoid unnecessary database load payload := fmt.Sprintf("%s::%s", destinationID, topic) cond := sync.NewCond(&sync.Mutex{}) + cond.L.Lock() _, loaded := s.workflowNotificationsMap.LoadOrStore(payload, cond) if loaded { + cond.L.Unlock() s.logger.Error("Receive already called for workflow", "destination_id", destinationID) return nil, newWorkflowConflictIDError(destinationID) } @@ -1845,15 +1851,12 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { query := fmt.Sprintf(`SELECT EXISTS (SELECT 1 FROM %s.notifications WHERE destination_uuid = $1 AND topic = $2)`, pgx.Identifier{s.schema}.Sanitize()) err = s.pool.QueryRow(ctx, query, destinationID, topic).Scan(&exists) if err != nil { + cond.L.Unlock() return false, fmt.Errorf("failed to check message: %w", err) } if !exists { - // Wait for notifications using condition variable with timeout pattern - s.logger.Debug("Waiting for notification on condition variable", "payload", payload) - done := make(chan struct{}) go func() { - cond.L.Lock() defer cond.L.Unlock() cond.Wait() close(done) @@ -1870,13 +1873,14 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { select { case <-done: - s.logger.Debug("Received notification on condition variable", "payload", payload) case <-time.After(timeout): s.logger.Warn("Recv() timeout reached", "payload", payload, "timeout", input.Timeout) case 
<-ctx.Done(): s.logger.Warn("Recv() context cancelled", "payload", payload, "cause", context.Cause(ctx)) return nil, ctx.Err() } + } else { + cond.L.Unlock() } // Find the oldest message and delete it atomically @@ -2056,8 +2060,10 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) // Create notification payload and condition variable payload := fmt.Sprintf("%s::%s", input.TargetWorkflowID, input.Key) cond := sync.NewCond(&sync.Mutex{}) + cond.L.Lock() existingCond, loaded := s.workflowEventsMap.LoadOrStore(payload, cond) if loaded { + cond.L.Unlock() // Reuse the existing condition variable cond = existingCond.(*sync.Cond) } @@ -2065,10 +2071,8 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) // Defer broadcast to ensure any waiting goroutines eventually unlock defer func() { cond.Broadcast() - // Clean up the condition variable after we're done, if we created it - if !loaded { - s.workflowEventsMap.Delete(payload) - } + // Clean up the condition variable after we're done (Delete is a no-op if the key doesn't exist) + s.workflowEventsMap.Delete(payload) }() // Check if the event already exists in the database @@ -2078,14 +2082,16 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) row := s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key) err := row.Scan(&valueString) if err != nil && err != pgx.ErrNoRows { + if !loaded { + cond.L.Unlock() + } return nil, fmt.Errorf("failed to query workflow event: %w", err) } - if err == pgx.ErrNoRows { + if err == pgx.ErrNoRows { // this implies isLaunched is True // Wait for notification with timeout using condition variable done := make(chan struct{}) go func() { - cond.L.Lock() defer cond.L.Unlock() cond.Wait() close(done) @@ -2547,3 +2553,176 @@ func maskPassword(dbURL string) (string, error) { return parsedURL.String(), nil } + +/*******************************/ +/******* RETRIER ********/ +/*******************************/ + +func isRetryablePGError(err error, logger *slog.Logger) bool { + if err == nil { + return false + } + + // If tx is closed (because failure happened between pgx trying to commit/rollback and setting tx.closed) + // pgx will always return pgx.ErrTxClosed again. + // This is only retryable if the caller retries with a new transaction object. + // Otherwise, retrying with the same closed transaction will always fail. + if errors.Is(err, pgx.ErrTxClosed) { + if logger != nil { + logger.Warn("Transaction is closed, retrying requires a new transaction object", "error", err) + } + return true + } + + // PostgreSQL codes indicating connection/admin shutdown etc. 
+ var pgerr *pgconn.PgError + if errors.As(err, &pgerr) { + switch pgerr.Code { + case pgerrcode.ConnectionException, + pgerrcode.ConnectionDoesNotExist, + pgerrcode.ConnectionFailure, + pgerrcode.SQLClientUnableToEstablishSQLConnection, + pgerrcode.SQLServerRejectedEstablishmentOfSQLConnection, + pgerrcode.AdminShutdown, + pgerrcode.CrashShutdown, + pgerrcode.CannotConnectNow: + return true + } + } + + // pgx aggregate for connect attempts: + var cerr *pgconn.ConnectError + if errors.As(err, &cerr) { + return true + } + + // Match most "connection closed" cases + if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "conn closed") { + return true + } + + // Net-level errors + var nerr net.Error + return errors.As(err, &nerr) +} + +// retryConfig holds the configuration for a retry operation +type retryConfig struct { + maxRetries int // -1 for infinite retries + baseDelay time.Duration + maxDelay time.Duration + backoffFactor float64 + jitterMin float64 + jitterMax float64 + retryCondition func(error, *slog.Logger) bool + logger *slog.Logger +} + +// retryOption is a functional option for configuring retry behavior +type retryOption func(*retryConfig) + +// withRetrierLogger sets the logger for the retrier +func withRetrierLogger(logger *slog.Logger) retryOption { + return func(c *retryConfig) { + c.logger = logger + } +} + +// retry executes a function with retry logic using functional options +func retry(ctx context.Context, fn func() error, options ...retryOption) error { + // Start with default configuration + config := &retryConfig{ + maxRetries: -1, + baseDelay: 100 * time.Millisecond, + maxDelay: 30 * time.Second, + backoffFactor: 2.0, + jitterMin: 0.95, + jitterMax: 1.05, + retryCondition: isRetryablePGError, + } + + // Apply options + for _, opt := range options { + opt(config) + } + + var lastErr error + delay := config.baseDelay + attempt := 0 + + for { + lastErr = fn() + + // Success and rollback case + if lastErr == nil { + return nil + } + + // Check if error is retryable + if !config.retryCondition(lastErr, config.logger) { + if config.logger != nil { + config.logger.Debug("Non-retryable error encountered", "error", lastErr) + } + return lastErr + } + + // Check if we should continue retrying + // If maxRetries is -1, retry indefinitely + if config.maxRetries >= 0 && attempt >= config.maxRetries { + return lastErr + } + + // Log retry attempt if logger is provided + if config.logger != nil { + config.logger.Debug("Retrying operation", + "attempt", attempt+1, + "max_retries", config.maxRetries, + "delay", delay, + "error", lastErr) + } + + // Apply jitter to the delay + jitterRange := config.jitterMax - config.jitterMin + jitterFactor := config.jitterMin + rand.Float64()*jitterRange // #nosec G404 -- trivial use of math/rand + jitteredDelay := time.Duration(float64(delay) * jitterFactor) + + // Wait before retrying with context cancellation support + select { + case <-time.After(jitteredDelay): + case <-ctx.Done(): + if config.logger != nil { + config.logger.Debug("Retry operation cancelled", "error", ctx.Err()) + } + return ctx.Err() + } + + // Calculate next delay with exponential backoff + delay = min(time.Duration(float64(delay)*config.backoffFactor), config.maxDelay) + + attempt++ + } +} + +// retryWithResult executes a function that returns a value with retry logic +// It uses the non-generic retry function under the hood +func retryWithResult[T any](ctx context.Context, fn func() (T, error), options ...retryOption) (T, error) { + var result T + var capturedErr 
error + + // Wrap the generic function to work with the non-generic retry + wrappedFn := func() error { + var err error + result, err = fn() + capturedErr = err + return err + } + + // Use the non-generic retry function + err := retry(ctx, wrappedFn, options...) + + // Return the last result and error + if err != nil { + return result, capturedErr + } + return result, nil +} diff --git a/dbos/workflow.go b/dbos/workflow.go index ed153a07..475d66fd 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -108,17 +108,21 @@ func (h *baseWorkflowHandle) GetStatus() (WorkflowStatus, error) { var err error if isWithinWorkflow { workflowStatuses, err = RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { - return c.systemDB.listWorkflows(ctx, listWorkflowsDBInput{ + return retryWithResult(ctx, func() ([]WorkflowStatus, error) { + return c.systemDB.listWorkflows(ctx, listWorkflowsDBInput{ + workflowIDs: []string{h.workflowID}, + loadInput: loadInput, + loadOutput: loadOutput, + }) + }, withRetrierLogger(c.logger)) + }, WithStepName("DBOS.getStatus")) + } else { + workflowStatuses, err = retryWithResult(c, func() ([]WorkflowStatus, error) { + return c.systemDB.listWorkflows(c, listWorkflowsDBInput{ workflowIDs: []string{h.workflowID}, loadInput: loadInput, loadOutput: loadOutput, }) - }, WithStepName("DBOS.getStatus")) - } else { - workflowStatuses, err = c.systemDB.listWorkflows(c, listWorkflowsDBInput{ - workflowIDs: []string{h.workflowID}, - loadInput: loadInput, - loadOutput: loadOutput, }) } if err != nil { @@ -170,7 +174,7 @@ func (h *workflowHandle[R]) GetResult() (R, error) { if isWithinWorkflow { encodedOutput, encErr := serialize(outcome.result) if encErr != nil { - return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) + return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr)) } recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: workflowState.workflowID, @@ -179,10 +183,12 @@ func (h *workflowHandle[R]) GetResult() (R, error) { output: encodedOutput, err: outcome.err, } - recordResultErr := h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput) + recordResultErr := retry(h.dbosContext, func() error { + return h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput) + }, withRetrierLogger(h.dbosContext.(*dbosContext).logger)) if recordResultErr != nil { h.dbosContext.(*dbosContext).logger.Error("failed to record get result", "error", recordResultErr) - return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) + return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("recording child workflow result: %w", recordResultErr)) } } return outcome.result, outcome.err @@ -193,7 +199,9 @@ type workflowPollingHandle[R any] struct { } func (h *workflowPollingHandle[R]) GetResult() (R, error) { - result, err := h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(h.dbosContext, h.workflowID) + result, err := retryWithResult(h.dbosContext, func() (any, error) { + return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(h.dbosContext, h.workflowID) + }, withRetrierLogger(h.dbosContext.(*dbosContext).logger)) if result != nil { typedResult, ok := result.(R) if !ok { @@ -205,7 +213,7 @@ func (h *workflowPollingHandle[R]) GetResult() 
(R, error) { if isWithinWorkflow { encodedOutput, encErr := serialize(typedResult) if encErr != nil { - return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) + return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr)) } recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: workflowState.workflowID, @@ -214,10 +222,12 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) { output: encodedOutput, err: err, } - recordResultErr := h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput) + recordResultErr := retry(h.dbosContext, func() error { + return h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput) + }, withRetrierLogger(h.dbosContext.(*dbosContext).logger)) if recordResultErr != nil { h.dbosContext.(*dbosContext).logger.Error("failed to record get result", "error", recordResultErr) - return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Sprintf("recording child workflow result: %v", recordResultErr)) + return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("recording child workflow result: %w", recordResultErr)) } } return typedResult, err @@ -437,18 +447,6 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], opts ... typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { typedInput, ok := input.(P) if !ok { - wfID, err := ctx.GetWorkflowID() - if err != nil { - return nil, fmt.Errorf("failed to get workflow ID: %w", err) - } - err = ctx.(*dbosContext).systemDB.updateWorkflowOutcome(WithoutCancel(ctx), updateWorkflowOutcomeDBInput{ - workflowID: wfID, - status: WorkflowStatusError, - err: newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)), - }) - if err != nil { - return nil, fmt.Errorf("failed to record unexpected input type error: %w", err) - } return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) } return fn(ctx, typedInput) @@ -460,7 +458,7 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], opts ... 
if err != nil { return nil, err } - return newWorkflowPollingHandle[any](ctx, handle.GetWorkflowID()), nil // this is only used by recovery and queue runner so far -- queue runner dismisses it + return newWorkflowPollingHandle[any](ctx, handle.GetWorkflowID()), nil // this is only used by recovery -- the queue runner dismisses it }) registerWorkflow(ctx, fqn, typeErasedWrapper, registrationParams.maxRetries, registrationParams.name) @@ -683,7 +681,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt if isChildWorkflow { childWorkflowID, err := c.systemDB.checkChildWorkflow(uncancellableCtx, parentWorkflowState.workflowID, parentWorkflowState.stepID) if err != nil { - return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) + return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Errorf("checking child workflow: %w", err)) } if childWorkflowID != nil { return newWorkflowPollingHandle[any](uncancellableCtx, *childWorkflowID), nil @@ -734,54 +732,70 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt Priority: int(params.priority), } - // Init status and record child workflow relationship in a single transaction - tx, err := c.systemDB.(*sysDB).pool.Begin(uncancellableCtx) - if err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to begin transaction: %v", err)) - } - defer tx.Rollback(uncancellableCtx) // Rollback if not committed + var earlyReturnPollingHandle *workflowPollingHandle[any] + var insertStatusResult *insertWorkflowResult - // Insert workflow status with transaction - insertInput := insertWorkflowStatusDBInput{ - status: workflowStatus, - maxRetries: params.maxRetries, - tx: tx, - } - insertStatusResult, err := c.systemDB.insertWorkflowStatus(uncancellableCtx, insertInput) - if err != nil { - c.logger.Error("failed to insert workflow status", "error", err, "workflow_id", workflowID) - return nil, err - } + // Init status and record child workflow relationship in a single transaction + err := retry(c, func() error { + tx, err := c.systemDB.(*sysDB).pool.Begin(uncancellableCtx) + if err != nil { + return newWorkflowExecutionError(workflowID, fmt.Errorf("failed to begin transaction: %w", err)) + } + defer tx.Rollback(uncancellableCtx) // Rollback if not committed - // Record child workflow relationship if this is a child workflow - if isChildWorkflow { - // Get the step ID that was used for generating the child workflow ID - childInput := recordChildWorkflowDBInput{ - parentWorkflowID: parentWorkflowState.workflowID, - childWorkflowID: workflowID, - stepName: params.workflowName, - stepID: parentWorkflowState.stepID, - tx: tx, + // Insert workflow status with transaction + insertInput := insertWorkflowStatusDBInput{ + status: workflowStatus, + maxRetries: params.maxRetries, + tx: tx, } - err = c.systemDB.recordChildWorkflow(uncancellableCtx, childInput) + insertStatusResult, err = c.systemDB.insertWorkflowStatus(uncancellableCtx, insertInput) if err != nil { - c.logger.Error("failed to record child workflow", "error", err, "parent_workflow_id", parentWorkflowState.workflowID, "child_workflow_id", workflowID) - return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow: %v", err)) + c.logger.Error("failed to insert workflow status", "error", err, "workflow_id", workflowID) + return err + } + + // Record child workflow relationship if this is a child workflow + 
if isChildWorkflow { + // Get the step ID that was used for generating the child workflow ID + childInput := recordChildWorkflowDBInput{ + parentWorkflowID: parentWorkflowState.workflowID, + childWorkflowID: workflowID, + stepName: params.workflowName, + stepID: parentWorkflowState.stepID, + tx: tx, + } + err = c.systemDB.recordChildWorkflow(uncancellableCtx, childInput) + if err != nil { + c.logger.Error("failed to record child workflow", "error", err, "parent_workflow_id", parentWorkflowState.workflowID, "child_workflow_id", workflowID) + return newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Errorf("recording child workflow: %w", err)) + } } - } - // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), - if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { - // Commit the transaction to update the number of attempts and/or enact the enqueue + // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), + if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { + // Commit the transaction to update the number of attempts and/or enact the enqueue + if err := tx.Commit(uncancellableCtx); err != nil { + return newWorkflowExecutionError(workflowID, fmt.Errorf("failed to commit transaction: %w", err)) + } + earlyReturnPollingHandle = newWorkflowPollingHandle[any](uncancellableCtx, workflowStatus.ID) + return nil + } + + // Commit the transaction. This must happen before we start the goroutine to ensure the workflow is found by steps in the database if err := tx.Commit(uncancellableCtx); err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) + return newWorkflowExecutionError(workflowID, fmt.Errorf("failed to commit transaction: %w", err)) } - return newWorkflowPollingHandle[any](uncancellableCtx, workflowStatus.ID), nil + + return nil + }, withRetrierLogger(c.logger)) + if err != nil { + return nil, err + } + if earlyReturnPollingHandle != nil { + return earlyReturnPollingHandle, nil } - // Channel to receive the outcome from the goroutine - // The buffer size of 1 allows the goroutine to send the outcome without blocking - // In addition it allows the channel to be garbage collected outcomeChan := make(chan workflowOutcome[any], 1) // Create workflow state to track step execution @@ -808,7 +822,9 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt // Register a cancel function that cancels the workflow in the DB as soon as the context is cancelled dbosCancelFunction := func() { c.logger.Info("Cancelling workflow", "workflow_id", workflowID) - err = c.systemDB.cancelWorkflow(uncancellableCtx, workflowID) + err = retry(c, func() error { + return c.systemDB.cancelWorkflow(uncancellableCtx, workflowID) + }, withRetrierLogger(c.logger)) if err != nil { c.logger.Error("Failed to cancel workflow", "error", err) } @@ -817,11 +833,6 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt stopFunc = context.AfterFunc(workflowCtx, dbosCancelFunction) } - // Commit the transaction. 
This must happen before we start the goroutine to ensure the workflow is found by steps in the database - if err := tx.Commit(uncancellableCtx); err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) - } - // Run the function in a goroutine c.workflowsWg.Add(1) go func() { @@ -836,7 +847,9 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt var dbosErr *DBOSError if errors.As(err, &dbosErr) && dbosErr.Code == ConflictingIDError { c.logger.Warn("Workflow ID conflict detected. Waiting for existing workflow to complete", "workflow_id", workflowID) - result, err = c.systemDB.awaitWorkflowResult(uncancellableCtx, workflowID) + result, err = retryWithResult(c, func() (any, error) { + return c.systemDB.awaitWorkflowResult(uncancellableCtx, workflowID) + }, withRetrierLogger(c.logger)) } else { status := WorkflowStatusSuccess @@ -852,12 +865,14 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt status = WorkflowStatusCancelled } - recordErr := c.systemDB.updateWorkflowOutcome(uncancellableCtx, updateWorkflowOutcomeDBInput{ - workflowID: workflowID, - status: status, - err: err, - output: result, - }) + recordErr := retry(c, func() error { + return c.systemDB.updateWorkflowOutcome(uncancellableCtx, updateWorkflowOutcomeDBInput{ + workflowID: workflowID, + status: status, + err: err, + output: result, + }) + }, withRetrierLogger(c.logger)) if recordErr != nil { c.logger.Error("Error recording workflow outcome", "workflow_id", workflowID, "error", recordErr) outcomeChan <- workflowOutcome[any]{result: nil, err: recordErr} @@ -1054,11 +1069,13 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) uncancellableCtx := WithoutCancel(c) // Check the step is cancelled, has already completed, or is called with a different name - recordedOutput, err := c.systemDB.checkOperationExecution(uncancellableCtx, checkOperationExecutionDBInput{ - workflowID: stepState.workflowID, - stepID: stepState.stepID, - stepName: stepOpts.stepName, - }) + recordedOutput, err := retryWithResult(c, func() (*recordedResult, error) { + return c.systemDB.checkOperationExecution(uncancellableCtx, checkOperationExecutionDBInput{ + workflowID: stepState.workflowID, + stepID: stepState.stepID, + stepName: stepOpts.stepName, + }) + }, withRetrierLogger(c.logger)) if err != nil { return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Sprintf("checking operation execution: %v", err)) } @@ -1121,7 +1138,9 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) err: stepError, output: stepOutput, } - recErr := c.systemDB.recordOperationResult(uncancellableCtx, dbInput) + recErr := retry(c, func() error { + return c.systemDB.recordOperationResult(uncancellableCtx, dbInput) + }, withRetrierLogger(c.logger)) if recErr != nil { return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Sprintf("recording step outcome: %v", recErr)) } @@ -1134,11 +1153,13 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) /****************************************/ func (c *dbosContext) Send(_ DBOSContext, destinationID string, message any, topic string) error { - return c.systemDB.send(c, WorkflowSendInput{ - DestinationID: destinationID, - Message: message, - Topic: topic, - }) + return retry(c, func() error { + return c.systemDB.send(c, WorkflowSendInput{ + DestinationID: destinationID, + Message: 
message, + Topic: topic, + }) + }, withRetrierLogger(c.logger)) } // Send sends a message to another workflow with type safety. @@ -1169,7 +1190,9 @@ func (c *dbosContext) Recv(_ DBOSContext, topic string, timeout time.Duration) ( Topic: topic, Timeout: timeout, } - return c.systemDB.recv(c, input) + return retryWithResult(c, func() (any, error) { + return c.systemDB.recv(c, input) + }, withRetrierLogger(c.logger)) } // Recv receives a message sent to this workflow with type safety. @@ -1207,10 +1230,12 @@ func Recv[R any](ctx DBOSContext, topic string, timeout time.Duration) (R, error } func (c *dbosContext) SetEvent(_ DBOSContext, key string, message any) error { - return c.systemDB.setEvent(c, WorkflowSetEventInput{ - Key: key, - Message: message, - }) + return retry(c, func() error { + return c.systemDB.setEvent(c, WorkflowSetEventInput{ + Key: key, + Message: message, + }) + }, withRetrierLogger(c.logger)) } // SetEvent sets a key-value event for the current workflow with type safety. @@ -1244,7 +1269,9 @@ func (c *dbosContext) GetEvent(_ DBOSContext, targetWorkflowID, key string, time Key: key, Timeout: timeout, } - return c.systemDB.getEvent(c, input) + return retryWithResult(c, func() (any, error) { + return c.systemDB.getEvent(c, input) + }, withRetrierLogger(c.logger)) } // GetEvent retrieves a key-value event from a target workflow with type safety. @@ -1281,7 +1308,9 @@ func GetEvent[R any](ctx DBOSContext, targetWorkflowID, key string, timeout time } func (c *dbosContext) Sleep(_ DBOSContext, duration time.Duration) (time.Duration, error) { - return c.systemDB.sleep(c, sleepInput{duration: duration, skipSleep: false}) + return retryWithResult(c, func() (time.Duration, error) { + return c.systemDB.sleep(c, sleepInput{duration: duration, skipSleep: false}) + }, withRetrierLogger(c.logger)) } // Sleep pauses workflow execution for the specified duration. @@ -1795,20 +1824,13 @@ func (c *dbosContext) ListWorkflows(_ DBOSContext, opts ...ListWorkflowsOption) // Call the context method to list workflows workflowState, ok := c.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil - var workflows []WorkflowStatus - var err error if isWithinWorkflow { - workflows, err = RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { + return RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { return c.systemDB.listWorkflows(ctx, dbInput) }, WithStepName("DBOS.listWorkflows")) } else { - workflows, err = c.systemDB.listWorkflows(c, dbInput) - } - if err != nil { - return nil, err + return c.systemDB.listWorkflows(c, dbInput) } - - return workflows, nil } // ListWorkflows retrieves a list of workflows based on the provided filters. diff --git a/integration/mocks_test.go b/integration/mocks_test.go index 53be1cb7..520bca3f 100644 --- a/integration/mocks_test.go +++ b/integration/mocks_test.go @@ -190,9 +190,4 @@ func TestMocks(t *testing.T) { if err != nil { t.Fatal(err) } - - mockCtx.AssertExpectations(t) - mockChildHandle.AssertExpectations(t) - mockGenericHandle.AssertExpectations(t) - fmt.Println("TestMocks completed successfully") }
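Note on the error-handling changes above: because newWorkflowExecutionError now stores the error it wraps and *DBOSError implements Unwrap, callers can reach the root cause with errors.Is and errors.As. A minimal sketch of what this enables, assuming the failure originated from Postgres; classifyWorkflowError is a hypothetical helper, not part of this patch:

package main

import (
	"errors"
	"fmt"

	"github.com/dbos-inc/dbos-transact-golang/dbos"
	"github.com/jackc/pgx/v5/pgconn"
)

// classifyWorkflowError inspects an error returned by RunWorkflow or Enqueue.
// errors.As walks the chain exposed by (*DBOSError).Unwrap, so a wrapped
// *pgconn.PgError (e.g. from a Postgres outage during chaos testing) is reachable.
func classifyWorkflowError(err error) string {
	var dbosErr *dbos.DBOSError
	if !errors.As(err, &dbosErr) {
		return "not a DBOS error"
	}
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		// The underlying Postgres SQLSTATE is now visible through the DBOS error.
		return fmt.Sprintf("DBOS error %d caused by Postgres error %s", dbosErr.Code, pgErr.Code)
	}
	return fmt.Sprintf("DBOS error %d: %s", dbosErr.Code, dbosErr.Message)
}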