Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,4 @@ Install the DBOS Transact package in your program:
github.com/dbos-inc/dbos-transact-go
```

You can store and export a Postgres connection string in the `DBOS_DATABASE_URL` environment variable for DBOS to manage your workflows state. By default, DBOS will use `postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable`.
You can store and export a Postgres connection string in the `DBOS_SYSTEM_DATABASE_URL` environment variable for DBOS to manage your workflows state. By default, DBOS will use `postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable`.
4 changes: 2 additions & 2 deletions dbos/admin-server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (

func TestAdminServer(t *testing.T) {
// Skip if database is not available
databaseURL := os.Getenv("DBOS_DATABASE_URL")
databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL")
if databaseURL == "" && os.Getenv("PGPASSWORD") == "" {
t.Skip("Database not available (DBOS_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
t.Skip("Database not available (DBOS_SYSTEM_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
}

t.Run("Admin server is not started without WithAdminServer option", func(t *testing.T) {
Expand Down
52 changes: 45 additions & 7 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"os"
"reflect"
"runtime"
Expand Down Expand Up @@ -78,12 +79,7 @@ func getExecutor() *executor {
var logger *slog.Logger

func getLogger() *slog.Logger {
if dbos == nil {
fmt.Println("warning: DBOS instance not initialized, using default logger")
return slog.New(slog.NewTextHandler(os.Stderr, nil))
}
if logger == nil {
fmt.Println("warning: DBOS logger is nil, using default logger")
if dbos == nil || logger == nil {
return slog.New(slog.NewTextHandler(os.Stderr, nil))
}
return logger
Expand All @@ -92,6 +88,40 @@ func getLogger() *slog.Logger {
type config struct {
logger *slog.Logger
adminServer bool
databaseURL string
appName string
}

// NewConfig merges configuration from two sources in order of precedence:
// 1. programmatic configuration
// 2. environment variables
// Finally, it applies default values if needed.
func NewConfig(programmaticConfig config) *config {
dbosConfig := &config{}

// Start with environment variables (lowest precedence)
if dbURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL"); dbURL != "" {
dbosConfig.databaseURL = dbURL
}

// Override with programmatic configuration (highest precedence)
if len(programmaticConfig.databaseURL) > 0 {
dbosConfig.databaseURL = programmaticConfig.databaseURL
}
if len(programmaticConfig.appName) > 0 {
dbosConfig.appName = programmaticConfig.appName
}
// Copy over parameters that can only be set programmatically
dbosConfig.logger = programmaticConfig.logger
dbosConfig.adminServer = programmaticConfig.adminServer

// Load defaults
if len(dbosConfig.databaseURL) == 0 {
getLogger().Info("Using default database URL: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable")
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
dbosConfig.databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
}
return dbosConfig
}

type LaunchOption func(*config)
Expand All @@ -108,18 +138,26 @@ func WithAdminServer() LaunchOption {
}
}

func WithDatabaseURL(url string) LaunchOption {
return func(config *config) {
config.databaseURL = url
}
}

func Launch(options ...LaunchOption) error {
if dbos != nil {
fmt.Println("warning: DBOS instance already initialized, skipping re-initialization")
return NewInitializationError("DBOS already initialized")
}

// Load & process the configuration
config := &config{
logger: slog.New(slog.NewTextHandler(os.Stderr, nil)),
}
for _, option := range options {
option(config)
}
config = NewConfig(*config)

logger = config.logger

Expand All @@ -139,7 +177,7 @@ func Launch(options ...LaunchOption) error {
APP_ID = os.Getenv("DBOS__APPID")

// Create the system database
systemDB, err := NewSystemDatabase()
systemDB, err := NewSystemDatabase(config.databaseURL)
if err != nil {
return NewInitializationError(fmt.Sprintf("failed to create system database: %v", err))
}
Expand Down
7 changes: 2 additions & 5 deletions dbos/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,14 @@ func queueRunner(ctx context.Context) {

pollingInterval := baseInterval

// XXX doing this lets the dequeue and the task invokation survive the context cancellation
// We might be OK with not doing this. During the tests it results in all sorts of error inside the two functions above due to context cancellation
runnerContext := context.WithoutCancel(ctx)
Comment on lines -114 to -116
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes a small race condition in the TestWorkerConcurrency test -- I might have to change this test altogether eventually

for {
hasBackoffError := false

// Iterate through all queues in the registry
for queueName, queue := range workflowQueueRegistry {
getLogger().Debug("Processing queue", "queue_name", queueName)
// Call DequeueWorkflows for each queue
dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(runnerContext, queue)
dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(ctx, queue)
if err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
switch pgErr.Code {
Expand Down Expand Up @@ -164,7 +161,7 @@ func queueRunner(ctx context.Context) {
}
}

_, err := registeredWorkflow.wrappedFunction(runnerContext, input, WithWorkflowID(workflow.id))
_, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id))
if err != nil {
getLogger().Error("Error recovering workflow", "error", err)
}
Expand Down
12 changes: 12 additions & 0 deletions dbos/queues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,12 @@ func TestWorkerConcurrency(t *testing.T) {
t.Fatalf("expected 3 workflows to be enqueued, got %d", len(workflows))
}

// Stop the queue runner before changing executor ID to avoid race conditions
stopQueueRunner()
// Change the EXECUTOR_ID global variable to a different value
EXECUTOR_ID = "worker-2"
// Restart the queue runner
restartQueueRunner()

// Wait for the second workflow to start on the second worker
startEvents[1].Wait()
Expand Down Expand Up @@ -533,8 +537,12 @@ func TestWorkerConcurrency(t *testing.T) {
if result1 != 0 {
t.Fatalf("expected result from blocking workflow 1 to be 0, got %v", result1)
}
// Stop the queue runner before changing executor ID to avoid race conditions
stopQueueRunner()
// Change the executor again and wait for the third workflow to start
EXECUTOR_ID = "local"
// Restart the queue runner
restartQueueRunner()
startEvents[2].Wait()
// Ensure the fourth workflow is not started yet
if startEvents[3].IsSet {
Expand All @@ -561,8 +569,12 @@ func TestWorkerConcurrency(t *testing.T) {
if result2 != 1 {
t.Fatalf("expected result from blocking workflow 2 to be 1, got %v", result2)
}
// Stop the queue runner before changing executor ID to avoid race conditions
stopQueueRunner()
// change executor again and wait for the fourth workflow to start
EXECUTOR_ID = "worker-2"
// Restart the queue runner
restartQueueRunner()
startEvents[3].Wait()
// Check no workflow is enqueued
workflows, err = getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{
Expand Down
25 changes: 8 additions & 17 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"embed"
"errors"
"fmt"
"net/url"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -89,6 +87,7 @@ func createDatabaseIfNotExists(databaseURL string) error {
if err != nil {
return NewInitializationError(fmt.Sprintf("failed to create database %s: %v", dbName, err))
}
getLogger().Info("Database created", "name", dbName)
}

return nil
Expand Down Expand Up @@ -125,36 +124,28 @@ func runMigrations(databaseURL string) error {
}

// New creates a new SystemDatabase instance and runs migrations
func NewSystemDatabase() (SystemDatabase, error) {
// TODO: pass proper config
databaseURL := os.Getenv("DBOS_DATABASE_URL")
if databaseURL == "" {
fmt.Println("DBOS_DATABASE_URL not set, using default: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable")
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
}

func NewSystemDatabase(databaseURL string) (SystemDatabase, error) {
// Create the database if it doesn't exist
if err := createDatabaseIfNotExists(databaseURL); err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to create database: %v", err))
return nil, fmt.Errorf("failed to create database: %v", err)
}

// Run migrations first
if err := runMigrations(databaseURL); err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to run migrations: %v", err))
return nil, fmt.Errorf("failed to run migrations: %v", err)
}

// Create pgx pool
pool, err := pgxpool.New(context.Background(), databaseURL)
if err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to create connection pool: %v", err))
return nil, fmt.Errorf("failed to create connection pool: %v", err)
}

// Test the connection
// FIXME: remove this
if err := pool.Ping(context.Background()); err != nil {
pool.Close()
return nil, NewInitializationError(fmt.Sprintf("failed to ping database: %v", err))
return nil, fmt.Errorf("failed to ping database: %v", err)
}

// Create a map of notification payloads to channels
Expand All @@ -163,7 +154,7 @@ func NewSystemDatabase() (SystemDatabase, error) {
// Create a connection to listen on notifications
config, err := pgconn.ParseConfig(databaseURL)
if err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to parse database URL: %v", err))
return nil, fmt.Errorf("failed to parse database URL: %v", err)
}
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
if n.Channel == "dbos_notifications_channel" {
Expand All @@ -180,7 +171,7 @@ func NewSystemDatabase() (SystemDatabase, error) {
}
notificationListenerConnection, err := pgconn.ConnectConfig(context.Background(), config)
if err != nil {
return nil, NewInitializationError(fmt.Sprintf("failed to connect notification listener to database: %v", err))
return nil, fmt.Errorf("failed to connect notification listener to database: %v", err)
}

return &systemDatabase{
Expand Down
31 changes: 29 additions & 2 deletions dbos/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func setupDBOS(t *testing.T) {
t.Helper()

databaseURL := os.Getenv("DBOS_DATABASE_URL")
databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL")
if databaseURL == "" {
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
Expand All @@ -30,7 +30,7 @@ func setupDBOS(t *testing.T) {

dbName := parsedURL.Database
if dbName == "" {
t.Skip("DBOS_DATABASE_URL does not specify a database name, skipping integration test")
t.Skip("DBOS_SYSTEM_DATABASE_URL does not specify a database name, skipping integration test")
}

postgresURL := parsedURL.Copy()
Expand Down Expand Up @@ -97,6 +97,33 @@ func (e *Event) Clear() {
}

/* Helpers */

// stopQueueRunner stops the queue runner for testing purposes
func stopQueueRunner() {
if dbos != nil && dbos.queueRunnerCancelFunc != nil {
dbos.queueRunnerCancelFunc()
// Wait for queue runner to finish
<-dbos.queueRunnerDone
}
}

// restartQueueRunner restarts the queue runner for testing purposes
func restartQueueRunner() {
if dbos != nil {
// Create new context and cancel function
ctx, cancel := context.WithCancel(context.Background())
dbos.queueRunnerCtx = ctx
dbos.queueRunnerCancelFunc = cancel
dbos.queueRunnerDone = make(chan struct{})

// Start the queue runner in a goroutine
go func() {
defer close(dbos.queueRunnerDone)
queueRunner(ctx)
}()
}
}

func equal(a, b []int) bool {
if len(a) != len(b) {
return false
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
go.uber.org/atomic v1.7.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/text v0.26.0 // indirect
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
Expand All @@ -73,8 +73,8 @@ go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
Expand Down
Loading