Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,4 @@ Install the DBOS Transact package in your program:
github.com/dbos-inc/dbos-transact-go
```

You can store and export a Postgres connection string in the `DBOS_DATABASE_URL` environment variable for DBOS to manage your workflows state. By default, DBOS will use `postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable`.
You can store and export a Postgres connection string in the `DBOS_SYSTEM_DATABASE_URL` environment variable for DBOS to manage your workflows state. By default, DBOS will use `postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable`.
4 changes: 2 additions & 2 deletions dbos/admin-server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (

func TestAdminServer(t *testing.T) {
// Skip if database is not available
databaseURL := os.Getenv("DBOS_DATABASE_URL")
databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL")
if databaseURL == "" && os.Getenv("PGPASSWORD") == "" {
t.Skip("Database not available (DBOS_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
t.Skip("Database not available (DBOS_SYSTEM_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
}

t.Run("Admin server is not started without WithAdminServer option", func(t *testing.T) {
Expand Down
52 changes: 45 additions & 7 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"os"
"reflect"
"runtime"
Expand Down Expand Up @@ -78,12 +79,7 @@ func getExecutor() *executor {
var logger *slog.Logger

func getLogger() *slog.Logger {
if dbos == nil {
fmt.Println("warning: DBOS instance not initialized, using default logger")
return slog.New(slog.NewTextHandler(os.Stderr, nil))
}
if logger == nil {
fmt.Println("warning: DBOS logger is nil, using default logger")
if dbos == nil || logger == nil {
return slog.New(slog.NewTextHandler(os.Stderr, nil))
}
return logger
Expand All @@ -92,6 +88,40 @@ func getLogger() *slog.Logger {
type config struct {
logger *slog.Logger
adminServer bool
databaseURL string
appName string
}

// NewConfig merges configuration from two sources in order of precedence:
// 1. programmatic configuration
// 2. environment variables
// Finally, it applies default values if needed.
func NewConfig(programmaticConfig config) *config {
dbosConfig := &config{}

// Start with environment variables (lowest precedence)
if dbURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL"); dbURL != "" {
dbosConfig.databaseURL = dbURL
}

// Override with programmatic configuration (highest precedence)
if len(programmaticConfig.databaseURL) > 0 {
dbosConfig.databaseURL = programmaticConfig.databaseURL
}
if len(programmaticConfig.appName) > 0 {
dbosConfig.appName = programmaticConfig.appName
}
// Copy over parameters that can only be set programmatically
dbosConfig.logger = programmaticConfig.logger
dbosConfig.adminServer = programmaticConfig.adminServer

// Load defaults
if len(dbosConfig.databaseURL) == 0 {
getLogger().Info("Using default database URL: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable")
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
dbosConfig.databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
}
return dbosConfig
}

type LaunchOption func(*config)
Expand All @@ -108,18 +138,26 @@ func WithAdminServer() LaunchOption {
}
}

func WithDatabaseURL(url string) LaunchOption {
return func(config *config) {
config.databaseURL = url
}
}

func Launch(options ...LaunchOption) error {
if dbos != nil {
fmt.Println("warning: DBOS instance already initialized, skipping re-initialization")
return NewInitializationError("DBOS already initialized")
}

// Load & process the configuration
config := &config{
logger: slog.New(slog.NewTextHandler(os.Stderr, nil)),
}
for _, option := range options {
option(config)
}
config = NewConfig(*config)

logger = config.logger

Expand All @@ -139,7 +177,7 @@ func Launch(options ...LaunchOption) error {
APP_ID = os.Getenv("DBOS__APPID")

// Create the system database
systemDB, err := NewSystemDatabase()
systemDB, err := NewSystemDatabase(config.databaseURL)
if err != nil {
return NewInitializationError(fmt.Sprintf("failed to create system database: %v", err))
}
Expand Down
23 changes: 12 additions & 11 deletions dbos/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ import (
)

var (
workflowQueueRegistry = make(map[string]WorkflowQueue)
DBOS_INTERNAL_QUEUE_NAME = "_dbos_internal_queue"
_ = NewWorkflowQueue(DBOS_INTERNAL_QUEUE_NAME)
workflowQueueRegistry = make(map[string]WorkflowQueue)
_ = NewWorkflowQueue(_DBOS_INTERNAL_QUEUE_NAME)
)

const (
_DBOS_INTERNAL_QUEUE_NAME = "_dbos_internal_queue"
_DEFAULT_MAX_TASKS_PER_ITERATION = 100
)

// RateLimiter represents a rate limiting configuration
Expand All @@ -31,7 +35,7 @@ type WorkflowQueue struct {
GlobalConcurrency *int
PriorityEnabled bool
Limiter *RateLimiter
MaxTasksPerIteration uint
MaxTasksPerIteration int
}

// QueueOption is a functional option for configuring a workflow queue
Expand Down Expand Up @@ -61,7 +65,7 @@ func WithRateLimiter(limiter *RateLimiter) QueueOption {
}
}

func WithMaxTasksPerIteration(maxTasks uint) QueueOption {
func WithMaxTasksPerIteration(maxTasks int) QueueOption {
return func(q *WorkflowQueue) {
q.MaxTasksPerIteration = maxTasks
}
Expand All @@ -84,7 +88,7 @@ func NewWorkflowQueue(name string, options ...QueueOption) WorkflowQueue {
GlobalConcurrency: nil,
PriorityEnabled: false,
Limiter: nil,
MaxTasksPerIteration: 100, // Default max tasks per iteration
MaxTasksPerIteration: _DEFAULT_MAX_TASKS_PER_ITERATION,
}

// Apply functional options
Expand All @@ -111,17 +115,14 @@ func queueRunner(ctx context.Context) {

pollingInterval := baseInterval

// XXX doing this lets the dequeue and the task invokation survive the context cancellation
// We might be OK with not doing this. During the tests it results in all sorts of error inside the two functions above due to context cancellation
runnerContext := context.WithoutCancel(ctx)
Comment on lines -114 to -116
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes a small race condition in the TestWorkerConcurrency test -- I might have to change this test altogether eventually

for {
hasBackoffError := false

// Iterate through all queues in the registry
for queueName, queue := range workflowQueueRegistry {
getLogger().Debug("Processing queue", "queue_name", queueName)
// Call DequeueWorkflows for each queue
dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(runnerContext, queue)
dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(ctx, queue)
if err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
switch pgErr.Code {
Expand Down Expand Up @@ -164,7 +165,7 @@ func queueRunner(ctx context.Context) {
}
}

_, err := registeredWorkflow.wrappedFunction(runnerContext, input, WithWorkflowID(workflow.id))
_, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id))
if err != nil {
getLogger().Error("Error recovering workflow", "error", err)
}
Expand Down
13 changes: 13 additions & 0 deletions dbos/queues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ func TestGlobalConcurrency(t *testing.T) {

// Wait for the first workflow to start
workflowEvent1.Wait()
time.Sleep(2 * time.Second) // Wait for a few seconds to let the queue runner loop

// Ensure the second workflow has not started yet
if workflowEvent2.IsSet {
Expand Down Expand Up @@ -504,8 +505,12 @@ func TestWorkerConcurrency(t *testing.T) {
t.Fatalf("expected 3 workflows to be enqueued, got %d", len(workflows))
}

// Stop the queue runner before changing executor ID to avoid race conditions
stopQueueRunner()
// Change the EXECUTOR_ID global variable to a different value
EXECUTOR_ID = "worker-2"
// Restart the queue runner
restartQueueRunner()

// Wait for the second workflow to start on the second worker
startEvents[1].Wait()
Expand Down Expand Up @@ -533,8 +538,12 @@ func TestWorkerConcurrency(t *testing.T) {
if result1 != 0 {
t.Fatalf("expected result from blocking workflow 1 to be 0, got %v", result1)
}
// Stop the queue runner before changing executor ID to avoid race conditions
stopQueueRunner()
// Change the executor again and wait for the third workflow to start
EXECUTOR_ID = "local"
// Restart the queue runner
restartQueueRunner()
startEvents[2].Wait()
// Ensure the fourth workflow is not started yet
if startEvents[3].IsSet {
Expand All @@ -561,8 +570,12 @@ func TestWorkerConcurrency(t *testing.T) {
if result2 != 1 {
t.Fatalf("expected result from blocking workflow 2 to be 1, got %v", result2)
}
// Stop the queue runner before changing executor ID to avoid race conditions
stopQueueRunner()
// change executor again and wait for the fourth workflow to start
EXECUTOR_ID = "worker-2"
// Restart the queue runner
restartQueueRunner()
startEvents[3].Wait()
// Check no workflow is enqueued
workflows, err = getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{
Expand Down
34 changes: 12 additions & 22 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"embed"
"errors"
"fmt"
"net/url"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -89,6 +87,7 @@ func createDatabaseIfNotExists(databaseURL string) error {
if err != nil {
return NewInitializationError(fmt.Sprintf("failed to create database %s: %v", dbName, err))
}
getLogger().Info("Database created", "name", dbName)
}

return nil
Expand Down Expand Up @@ -125,36 +124,28 @@ func runMigrations(databaseURL string) error {
}

// New creates a new SystemDatabase instance and runs migrations
func NewSystemDatabase() (SystemDatabase, error) {
// TODO: pass proper config
databaseURL := os.Getenv("DBOS_DATABASE_URL")
if databaseURL == "" {
fmt.Println("DBOS_DATABASE_URL not set, using default: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable")
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
}

func NewSystemDatabase(databaseURL string) (SystemDatabase, error) {
// Create the database if it doesn't exist
if err := createDatabaseIfNotExists(databaseURL); err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to create database: %v", err))
return nil, fmt.Errorf("failed to create database: %v", err)
}

// Run migrations first
if err := runMigrations(databaseURL); err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to run migrations: %v", err))
return nil, fmt.Errorf("failed to run migrations: %v", err)
}

// Create pgx pool
pool, err := pgxpool.New(context.Background(), databaseURL)
if err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to create connection pool: %v", err))
return nil, fmt.Errorf("failed to create connection pool: %v", err)
}

// Test the connection
// FIXME: remove this
if err := pool.Ping(context.Background()); err != nil {
pool.Close()
return nil, NewInitializationError(fmt.Sprintf("failed to ping database: %v", err))
return nil, fmt.Errorf("failed to ping database: %v", err)
}

// Create a map of notification payloads to channels
Expand All @@ -163,7 +154,7 @@ func NewSystemDatabase() (SystemDatabase, error) {
// Create a connection to listen on notifications
config, err := pgconn.ParseConfig(databaseURL)
if err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to parse database URL: %v", err))
return nil, fmt.Errorf("failed to parse database URL: %v", err)
}
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
if n.Channel == "dbos_notifications_channel" {
Expand All @@ -180,7 +171,7 @@ func NewSystemDatabase() (SystemDatabase, error) {
}
notificationListenerConnection, err := pgconn.ConnectConfig(context.Background(), config)
if err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to connect notification listener to database: %v", err))
return nil, fmt.Errorf("failed to connect notification listener to database: %v", err)
}

return &systemDatabase{
Expand Down Expand Up @@ -1330,7 +1321,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue
getLogger().Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", workerConcurrency)
}
availableWorkerTasks := max(workerConcurrency-localPendingWorkflows, 0)
maxTasks = uint(availableWorkerTasks)
maxTasks = availableWorkerTasks
}

// Check global concurrency limit
Expand All @@ -1345,8 +1336,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue
getLogger().Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", concurrency)
}
availableTasks := max(concurrency-globalPendingWorkflows, 0)
if uint(availableTasks) < maxTasks {
maxTasks = uint(availableTasks)
if availableTasks < maxTasks {
maxTasks = availableTasks
}
}
}
Expand Down Expand Up @@ -1383,8 +1374,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue
%s`, lockClause)
}

// Add limit if maxTasks is finite
if maxTasks > 0 {
if maxTasks >= 0 {
query += fmt.Sprintf(" LIMIT %d", int(maxTasks))
}

Expand Down
Loading
Loading