Skip to content

Commit 8b55202

Browse files
authored
System DB config (#33)
Let users provide the system db url programmatically
1 parent ee7eb93 commit 8b55202

File tree

10 files changed

+122
-52
lines changed

10 files changed

+122
-52
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,4 @@ Install the DBOS Transact package in your program:
173173
github.com/dbos-inc/dbos-transact-go
174174
```
175175

176-
You can store and export a Postgres connection string in the `DBOS_DATABASE_URL` environment variable for DBOS to manage your workflows state. By default, DBOS will use `postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable`.
176+
You can store and export a Postgres connection string in the `DBOS_SYSTEM_DATABASE_URL` environment variable for DBOS to manage your workflows state. By default, DBOS will use `postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable`.

dbos/admin-server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ import (
1313

1414
func TestAdminServer(t *testing.T) {
1515
// Skip if database is not available
16-
databaseURL := os.Getenv("DBOS_DATABASE_URL")
16+
databaseURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL")
1717
if databaseURL == "" && os.Getenv("PGPASSWORD") == "" {
18-
t.Skip("Database not available (DBOS_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
18+
t.Skip("Database not available (DBOS_SYSTEM_DATABASE_URL and PGPASSWORD not set), skipping DBOS integration tests")
1919
}
2020

2121
t.Run("Admin server is not started without WithAdminServer option", func(t *testing.T) {

dbos/dbos.go

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/hex"
77
"fmt"
88
"log/slog"
9+
"net/url"
910
"os"
1011
"reflect"
1112
"runtime"
@@ -78,12 +79,7 @@ func getExecutor() *executor {
7879
var logger *slog.Logger
7980

8081
func getLogger() *slog.Logger {
81-
if dbos == nil {
82-
fmt.Println("warning: DBOS instance not initialized, using default logger")
83-
return slog.New(slog.NewTextHandler(os.Stderr, nil))
84-
}
85-
if logger == nil {
86-
fmt.Println("warning: DBOS logger is nil, using default logger")
82+
if dbos == nil || logger == nil {
8783
return slog.New(slog.NewTextHandler(os.Stderr, nil))
8884
}
8985
return logger
@@ -92,6 +88,40 @@ func getLogger() *slog.Logger {
9288
type config struct {
9389
logger *slog.Logger
9490
adminServer bool
91+
databaseURL string
92+
appName string
93+
}
94+
95+
// NewConfig merges configuration from two sources in order of precedence:
96+
// 1. programmatic configuration
97+
// 2. environment variables
98+
// Finally, it applies default values if needed.
99+
func NewConfig(programmaticConfig config) *config {
100+
dbosConfig := &config{}
101+
102+
// Start with environment variables (lowest precedence)
103+
if dbURL := os.Getenv("DBOS_SYSTEM_DATABASE_URL"); dbURL != "" {
104+
dbosConfig.databaseURL = dbURL
105+
}
106+
107+
// Override with programmatic configuration (highest precedence)
108+
if len(programmaticConfig.databaseURL) > 0 {
109+
dbosConfig.databaseURL = programmaticConfig.databaseURL
110+
}
111+
if len(programmaticConfig.appName) > 0 {
112+
dbosConfig.appName = programmaticConfig.appName
113+
}
114+
// Copy over parameters that can only be set programmatically
115+
dbosConfig.logger = programmaticConfig.logger
116+
dbosConfig.adminServer = programmaticConfig.adminServer
117+
118+
// Load defaults
119+
if len(dbosConfig.databaseURL) == 0 {
120+
getLogger().Info("Using default database URL: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable")
121+
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
122+
dbosConfig.databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
123+
}
124+
return dbosConfig
95125
}
96126

97127
type LaunchOption func(*config)
@@ -108,18 +138,26 @@ func WithAdminServer() LaunchOption {
108138
}
109139
}
110140

141+
func WithDatabaseURL(url string) LaunchOption {
142+
return func(config *config) {
143+
config.databaseURL = url
144+
}
145+
}
146+
111147
func Launch(options ...LaunchOption) error {
112148
if dbos != nil {
113149
fmt.Println("warning: DBOS instance already initialized, skipping re-initialization")
114150
return NewInitializationError("DBOS already initialized")
115151
}
116152

153+
// Load & process the configuration
117154
config := &config{
118155
logger: slog.New(slog.NewTextHandler(os.Stderr, nil)),
119156
}
120157
for _, option := range options {
121158
option(config)
122159
}
160+
config = NewConfig(*config)
123161

124162
logger = config.logger
125163

@@ -139,7 +177,7 @@ func Launch(options ...LaunchOption) error {
139177
APP_ID = os.Getenv("DBOS__APPID")
140178

141179
// Create the system database
142-
systemDB, err := NewSystemDatabase()
180+
systemDB, err := NewSystemDatabase(config.databaseURL)
143181
if err != nil {
144182
return NewInitializationError(fmt.Sprintf("failed to create system database: %v", err))
145183
}

dbos/queue.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@ import (
1414
)
1515

1616
var (
17-
workflowQueueRegistry = make(map[string]WorkflowQueue)
18-
DBOS_INTERNAL_QUEUE_NAME = "_dbos_internal_queue"
19-
_ = NewWorkflowQueue(DBOS_INTERNAL_QUEUE_NAME)
17+
workflowQueueRegistry = make(map[string]WorkflowQueue)
18+
_ = NewWorkflowQueue(_DBOS_INTERNAL_QUEUE_NAME)
19+
)
20+
21+
const (
22+
_DBOS_INTERNAL_QUEUE_NAME = "_dbos_internal_queue"
23+
_DEFAULT_MAX_TASKS_PER_ITERATION = 100
2024
)
2125

2226
// RateLimiter represents a rate limiting configuration
@@ -31,7 +35,7 @@ type WorkflowQueue struct {
3135
GlobalConcurrency *int
3236
PriorityEnabled bool
3337
Limiter *RateLimiter
34-
MaxTasksPerIteration uint
38+
MaxTasksPerIteration int
3539
}
3640

3741
// QueueOption is a functional option for configuring a workflow queue
@@ -61,7 +65,7 @@ func WithRateLimiter(limiter *RateLimiter) QueueOption {
6165
}
6266
}
6367

64-
func WithMaxTasksPerIteration(maxTasks uint) QueueOption {
68+
func WithMaxTasksPerIteration(maxTasks int) QueueOption {
6569
return func(q *WorkflowQueue) {
6670
q.MaxTasksPerIteration = maxTasks
6771
}
@@ -84,7 +88,7 @@ func NewWorkflowQueue(name string, options ...QueueOption) WorkflowQueue {
8488
GlobalConcurrency: nil,
8589
PriorityEnabled: false,
8690
Limiter: nil,
87-
MaxTasksPerIteration: 100, // Default max tasks per iteration
91+
MaxTasksPerIteration: _DEFAULT_MAX_TASKS_PER_ITERATION,
8892
}
8993

9094
// Apply functional options
@@ -111,17 +115,14 @@ func queueRunner(ctx context.Context) {
111115

112116
pollingInterval := baseInterval
113117

114-
// XXX doing this lets the dequeue and the task invokation survive the context cancellation
115-
// We might be OK with not doing this. During the tests it results in all sorts of error inside the two functions above due to context cancellation
116-
runnerContext := context.WithoutCancel(ctx)
117118
for {
118119
hasBackoffError := false
119120

120121
// Iterate through all queues in the registry
121122
for queueName, queue := range workflowQueueRegistry {
122123
getLogger().Debug("Processing queue", "queue_name", queueName)
123124
// Call DequeueWorkflows for each queue
124-
dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(runnerContext, queue)
125+
dequeuedWorkflows, err := getExecutor().systemDB.DequeueWorkflows(ctx, queue)
125126
if err != nil {
126127
if pgErr, ok := err.(*pgconn.PgError); ok {
127128
switch pgErr.Code {
@@ -164,7 +165,7 @@ func queueRunner(ctx context.Context) {
164165
}
165166
}
166167

167-
_, err := registeredWorkflow.wrappedFunction(runnerContext, input, WithWorkflowID(workflow.id))
168+
_, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id))
168169
if err != nil {
169170
getLogger().Error("Error recovering workflow", "error", err)
170171
}

dbos/queues_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ func TestGlobalConcurrency(t *testing.T) {
405405

406406
// Wait for the first workflow to start
407407
workflowEvent1.Wait()
408+
time.Sleep(2 * time.Second) // Wait for a few seconds to let the queue runner loop
408409

409410
// Ensure the second workflow has not started yet
410411
if workflowEvent2.IsSet {
@@ -504,8 +505,12 @@ func TestWorkerConcurrency(t *testing.T) {
504505
t.Fatalf("expected 3 workflows to be enqueued, got %d", len(workflows))
505506
}
506507

508+
// Stop the queue runner before changing executor ID to avoid race conditions
509+
stopQueueRunner()
507510
// Change the EXECUTOR_ID global variable to a different value
508511
EXECUTOR_ID = "worker-2"
512+
// Restart the queue runner
513+
restartQueueRunner()
509514

510515
// Wait for the second workflow to start on the second worker
511516
startEvents[1].Wait()
@@ -533,8 +538,12 @@ func TestWorkerConcurrency(t *testing.T) {
533538
if result1 != 0 {
534539
t.Fatalf("expected result from blocking workflow 1 to be 0, got %v", result1)
535540
}
541+
// Stop the queue runner before changing executor ID to avoid race conditions
542+
stopQueueRunner()
536543
// Change the executor again and wait for the third workflow to start
537544
EXECUTOR_ID = "local"
545+
// Restart the queue runner
546+
restartQueueRunner()
538547
startEvents[2].Wait()
539548
// Ensure the fourth workflow is not started yet
540549
if startEvents[3].IsSet {
@@ -561,8 +570,12 @@ func TestWorkerConcurrency(t *testing.T) {
561570
if result2 != 1 {
562571
t.Fatalf("expected result from blocking workflow 2 to be 1, got %v", result2)
563572
}
573+
// Stop the queue runner before changing executor ID to avoid race conditions
574+
stopQueueRunner()
564575
// change executor again and wait for the fourth workflow to start
565576
EXECUTOR_ID = "worker-2"
577+
// Restart the queue runner
578+
restartQueueRunner()
566579
startEvents[3].Wait()
567580
// Check no workflow is enqueued
568581
workflows, err = getExecutor().systemDB.ListWorkflows(context.Background(), ListWorkflowsDBInput{

dbos/system_database.go

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import (
55
"embed"
66
"errors"
77
"fmt"
8-
"net/url"
9-
"os"
108
"strings"
119
"sync"
1210
"time"
@@ -89,6 +87,7 @@ func createDatabaseIfNotExists(databaseURL string) error {
8987
if err != nil {
9088
return NewInitializationError(fmt.Sprintf("failed to create database %s: %v", dbName, err))
9189
}
90+
getLogger().Info("Database created", "name", dbName)
9291
}
9392

9493
return nil
@@ -125,36 +124,28 @@ func runMigrations(databaseURL string) error {
125124
}
126125

127126
// New creates a new SystemDatabase instance and runs migrations
128-
func NewSystemDatabase() (SystemDatabase, error) {
129-
// TODO: pass proper config
130-
databaseURL := os.Getenv("DBOS_DATABASE_URL")
131-
if databaseURL == "" {
132-
fmt.Println("DBOS_DATABASE_URL not set, using default: postgres://postgres:${PGPASSWORD}@localhost:5432/dbos?sslmode=disable")
133-
password := url.QueryEscape(os.Getenv("PGPASSWORD"))
134-
databaseURL = fmt.Sprintf("postgres://postgres:%s@localhost:5432/dbos?sslmode=disable", password)
135-
}
136-
127+
func NewSystemDatabase(databaseURL string) (SystemDatabase, error) {
137128
// Create the database if it doesn't exist
138129
if err := createDatabaseIfNotExists(databaseURL); err != nil {
139-
return nil, NewInitializationError(fmt.Sprintf("failed to create database: %v", err))
130+
return nil, fmt.Errorf("failed to create database: %v", err)
140131
}
141132

142133
// Run migrations first
143134
if err := runMigrations(databaseURL); err != nil {
144-
return nil, NewInitializationError(fmt.Sprintf("failed to run migrations: %v", err))
135+
return nil, fmt.Errorf("failed to run migrations: %v", err)
145136
}
146137

147138
// Create pgx pool
148139
pool, err := pgxpool.New(context.Background(), databaseURL)
149140
if err != nil {
150-
return nil, NewInitializationError(fmt.Sprintf("failed to create connection pool: %v", err))
141+
return nil, fmt.Errorf("failed to create connection pool: %v", err)
151142
}
152143

153144
// Test the connection
154145
// FIXME: remove this
155146
if err := pool.Ping(context.Background()); err != nil {
156147
pool.Close()
157-
return nil, NewInitializationError(fmt.Sprintf("failed to ping database: %v", err))
148+
return nil, fmt.Errorf("failed to ping database: %v", err)
158149
}
159150

160151
// Create a map of notification payloads to channels
@@ -163,7 +154,7 @@ func NewSystemDatabase() (SystemDatabase, error) {
163154
// Create a connection to listen on notifications
164155
config, err := pgconn.ParseConfig(databaseURL)
165156
if err != nil {
166-
return nil, NewInitializationError(fmt.Sprintf("failed to parse database URL: %v", err))
157+
return nil, fmt.Errorf("failed to parse database URL: %v", err)
167158
}
168159
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
169160
if n.Channel == "dbos_notifications_channel" {
@@ -180,7 +171,7 @@ func NewSystemDatabase() (SystemDatabase, error) {
180171
}
181172
notificationListenerConnection, err := pgconn.ConnectConfig(context.Background(), config)
182173
if err != nil {
183-
return nil, NewInitializationError(fmt.Sprintf("failed to connect notification listener to database: %v", err))
174+
return nil, fmt.Errorf("failed to connect notification listener to database: %v", err)
184175
}
185176

186177
return &systemDatabase{
@@ -1330,7 +1321,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue
13301321
getLogger().Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", workerConcurrency)
13311322
}
13321323
availableWorkerTasks := max(workerConcurrency-localPendingWorkflows, 0)
1333-
maxTasks = uint(availableWorkerTasks)
1324+
maxTasks = availableWorkerTasks
13341325
}
13351326

13361327
// Check global concurrency limit
@@ -1345,8 +1336,8 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue
13451336
getLogger().Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", concurrency)
13461337
}
13471338
availableTasks := max(concurrency-globalPendingWorkflows, 0)
1348-
if uint(availableTasks) < maxTasks {
1349-
maxTasks = uint(availableTasks)
1339+
if availableTasks < maxTasks {
1340+
maxTasks = availableTasks
13501341
}
13511342
}
13521343
}
@@ -1383,8 +1374,7 @@ func (s *systemDatabase) DequeueWorkflows(ctx context.Context, queue WorkflowQue
13831374
%s`, lockClause)
13841375
}
13851376

1386-
// Add limit if maxTasks is finite
1387-
if maxTasks > 0 {
1377+
if maxTasks >= 0 {
13881378
query += fmt.Sprintf(" LIMIT %d", int(maxTasks))
13891379
}
13901380

0 commit comments

Comments
 (0)