Skip to content

Commit fc7c18d

Browse files
committed
- Only create DB if we create a pool
- Validate pool first - Fix sysdb shutdown logic and add timeout on pool.Close() + only close our own pool during init failures
1 parent e851136 commit fc7c18d

File tree

2 files changed

+206
-24
lines changed

2 files changed

+206
-24
lines changed

dbos/dbos_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dbos
33
import (
44
"context"
55
"fmt"
6+
"log/slog"
67
"testing"
78
"time"
89

@@ -463,6 +464,48 @@ func TestCustomSystemDBSchema(t *testing.T) {
463464
}
464465

465466
func TestCustomPool(t *testing.T) {
467+
// Test workflows for custom pool testing
468+
type customPoolWorkflowInput struct {
469+
PartnerWorkflowID string
470+
Message string
471+
}
472+
473+
// Workflow A: Uses Send() and GetEvent() - waits for workflow B
474+
sendGetEventWorkflowCustom := func(ctx DBOSContext, input customPoolWorkflowInput) (string, error) {
475+
// Send a message to the partner workflow
476+
err := Send(ctx, input.PartnerWorkflowID, input.Message, "custom-pool-topic")
477+
if err != nil {
478+
return "", err
479+
}
480+
481+
// Wait for an event from the partner workflow
482+
result, err := GetEvent[string](ctx, input.PartnerWorkflowID, "custom-response-key", 5*time.Hour)
483+
if err != nil {
484+
return "", err
485+
}
486+
487+
return result, nil
488+
}
489+
490+
// Workflow B: Uses Recv() and SetEvent() - waits for workflow A
491+
recvSetEventWorkflowCustom := func(ctx DBOSContext, input customPoolWorkflowInput) (string, error) {
492+
// Receive a message from the partner workflow
493+
receivedMsg, err := Recv[string](ctx, "custom-pool-topic", 5*time.Hour)
494+
if err != nil {
495+
return "", err
496+
}
497+
498+
time.Sleep(1 * time.Second)
499+
500+
// Set an event for the partner workflow
501+
err = SetEvent(ctx, "custom-response-key", "response-from-custom-pool-workflow")
502+
if err != nil {
503+
return "", err
504+
}
505+
506+
return receivedMsg, nil
507+
}
508+
466509
t.Run("NewSystemDatabaseWithCustomPool", func(t *testing.T) {
467510
// Custom Pool
468511
databaseURL := getDatabaseURL()
@@ -506,6 +549,60 @@ func TestCustomPool(t *testing.T) {
506549
assert.Equal(t, 2*time.Hour, sysdbConfig.MaxConnLifetime)
507550
assert.Equal(t, 2*time.Minute, sysdbConfig.MaxConnIdleTime)
508551
assert.Equal(t, 10*time.Second, sysdbConfig.ConnConfig.ConnectTimeout)
552+
553+
// Register the test workflows
554+
RegisterWorkflow(customdbosContext, sendGetEventWorkflowCustom)
555+
RegisterWorkflow(customdbosContext, recvSetEventWorkflowCustom)
556+
557+
// Launch the DBOS context
558+
err = customdbosContext.Launch()
559+
require.NoError(t, err)
560+
defer dbosCtx.Shutdown(1 * time.Minute)
561+
562+
// Test RunWorkflow - start both workflows that will communicate with each other
563+
workflowAID := uuid.NewString()
564+
workflowBID := uuid.NewString()
565+
566+
// Start workflow B first (receiver)
567+
handleB, err := RunWorkflow(customdbosContext, recvSetEventWorkflowCustom, customPoolWorkflowInput{
568+
PartnerWorkflowID: workflowAID,
569+
Message: "custom-pool-message-from-b",
570+
}, WithWorkflowID(workflowBID))
571+
require.NoError(t, err, "failed to start recvSetEventWorkflowCustom")
572+
573+
// Small delay to ensure workflow B is ready to receive
574+
time.Sleep(100 * time.Millisecond)
575+
576+
// Start workflow A (sender)
577+
handleA, err := RunWorkflow(customdbosContext, sendGetEventWorkflowCustom, customPoolWorkflowInput{
578+
PartnerWorkflowID: workflowBID,
579+
Message: "custom-pool-message-from-a",
580+
}, WithWorkflowID(workflowAID))
581+
require.NoError(t, err, "failed to start sendGetEventWorkflowCustom")
582+
583+
// Wait for both workflows to complete
584+
resultA, err := handleA.GetResult()
585+
require.NoError(t, err, "failed to get result from workflow A")
586+
assert.Equal(t, "response-from-custom-pool-workflow", resultA, "workflow A should receive response from workflow B")
587+
588+
resultB, err := handleB.GetResult()
589+
require.NoError(t, err, "failed to get result from workflow B")
590+
assert.Equal(t, "custom-pool-message-from-a", resultB, "workflow B should receive message from workflow A")
591+
592+
// Test GetWorkflowSteps
593+
stepsA, err := GetWorkflowSteps(customdbosContext, workflowAID)
594+
require.NoError(t, err, "failed to get workflow A steps")
595+
require.Len(t, stepsA, 3, "workflow A should have 3 steps (Send + GetEvent + Sleep)")
596+
assert.Equal(t, "DBOS.send", stepsA[0].StepName, "first step should be Send")
597+
assert.Equal(t, "DBOS.getEvent", stepsA[1].StepName, "second step should be GetEvent")
598+
assert.Equal(t, "DBOS.sleep", stepsA[2].StepName, "third step should be Sleep")
599+
600+
stepsB, err := GetWorkflowSteps(customdbosContext, workflowBID)
601+
require.NoError(t, err, "failed to get workflow B steps")
602+
require.Len(t, stepsB, 3, "workflow B should have 3 steps (Recv + Sleep + SetEvent)")
603+
assert.Equal(t, "DBOS.recv", stepsB[0].StepName, "first step should be Recv")
604+
assert.Equal(t, "DBOS.sleep", stepsB[1].StepName, "second step should be Sleep")
605+
assert.Equal(t, "DBOS.setEvent", stepsB[2].StepName, "third step should be SetEvent")
509606
})
510607

511608
wf := func(ctx DBOSContext, input string) (string, error) {
@@ -539,4 +636,71 @@ func TestCustomPool(t *testing.T) {
539636
_, err = RunWorkflow(dbosCtx, wf, "test-input")
540637
require.NoError(t, err)
541638
})
639+
640+
t.Run("InvalidCustomPool", func(t *testing.T) {
641+
databaseURL := getDatabaseURL()
642+
poolConfig, err := pgxpool.ParseConfig(databaseURL)
643+
require.NoError(t, err)
644+
poolConfig.ConnConfig.Host = "invalid-host"
645+
pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
646+
require.NoError(t, err)
647+
648+
config := Config{
649+
DatabaseURL: databaseURL,
650+
AppName: "test-invalid-custom-pool",
651+
SystemDBPool: pool,
652+
}
653+
_, err = NewDBOSContext(context.Background(), config)
654+
require.Error(t, err)
655+
dbosErr, ok := err.(*DBOSError)
656+
require.True(t, ok, "expected DBOSError, got %T", err)
657+
assert.Equal(t, InitializationError, dbosErr.Code)
658+
expectedMsg := "Error initializing DBOS Transact: failed to create system database"
659+
assert.Contains(t, dbosErr.Message, expectedMsg)
660+
})
661+
662+
t.Run("DirectSystemDatabase", func(t *testing.T) {
663+
ctx, cancel := context.WithCancel(context.Background())
664+
databaseURL := getDatabaseURL()
665+
logger := slog.Default()
666+
667+
// Create custom pool
668+
poolConfig, err := pgxpool.ParseConfig(databaseURL)
669+
require.NoError(t, err)
670+
poolConfig.MaxConns = 15
671+
poolConfig.MinConns = 3
672+
customPool, err := pgxpool.NewWithConfig(ctx, poolConfig)
673+
require.NoError(t, err)
674+
defer customPool.Close()
675+
676+
// Create system database with custom pool
677+
sysDBInput := newSystemDatabaseInput{
678+
databaseURL: databaseURL,
679+
databaseSchema: "dbos_test_custom_direct",
680+
customPool: customPool,
681+
logger: logger,
682+
}
683+
684+
systemDB, err := newSystemDatabase(ctx, sysDBInput)
685+
require.NoError(t, err, "failed to create system database with custom pool")
686+
require.NotNil(t, systemDB)
687+
688+
// Launch the system database
689+
systemDB.launch(ctx)
690+
691+
require.Eventually(t, func() bool {
692+
conn, err := systemDB.(*sysDB).pool.Acquire(ctx)
693+
require.NoError(t, err)
694+
defer conn.Release()
695+
err = conn.Ping(ctx)
696+
require.NoError(t, err)
697+
return true
698+
}, 5*time.Second, 100*time.Millisecond, "system database should be reachable")
699+
700+
// Shutdown the system database
701+
cancel() // Cancel context
702+
shutdownTimeout := 2 * time.Second
703+
systemDB.shutdown(ctx, shutdownTimeout)
704+
assert.False(t, systemDB.(*sysDB).launched)
705+
})
542706
}

dbos/system_database.go

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,6 @@ type sysDB struct {
8080

8181
// createDatabaseIfNotExists creates the database if it doesn't exist
8282
func createDatabaseIfNotExists(ctx context.Context, pool *pgxpool.Pool, logger *slog.Logger) error {
83-
// Try to acquire a connection from the pool first
84-
poolConn, err := pool.Acquire(ctx)
85-
if err == nil {
86-
// Pool connection works, database likely exists
87-
poolConn.Release()
88-
return nil
89-
}
90-
// Fall through to database creation
91-
9283
// Get the database name from the pool config
9384
poolConfig := pool.Config()
9485
dbName := poolConfig.ConnConfig.Database
@@ -260,6 +251,16 @@ func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (syst
260251
var pool *pgxpool.Pool
261252
if customPool != nil {
262253
logger.Info("Using custom database connection pool")
254+
// Verify the pool is valid
255+
poolConn, err := customPool.Acquire(ctx)
256+
if err != nil {
257+
return nil, fmt.Errorf("failed to validate custom pool: %v", err)
258+
}
259+
err = poolConn.Ping(ctx)
260+
if err != nil {
261+
return nil, fmt.Errorf("failed to validate custom pool: %v", err)
262+
}
263+
poolConn.Release()
263264
pool = customPool
264265
} else {
265266
// Parse the connection string to get a config
@@ -285,21 +286,27 @@ func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (syst
285286
pool = newPool
286287
}
287288

288-
// Create the database if it doesn't exist
289-
if err := createDatabaseIfNotExists(ctx, pool, logger); err != nil {
290-
pool.Close()
291-
return nil, fmt.Errorf("failed to create database: %v", err)
289+
if customPool == nil {
290+
// Create the database if it doesn't exist
291+
if err := createDatabaseIfNotExists(ctx, pool, logger); err != nil {
292+
pool.Close()
293+
return nil, fmt.Errorf("failed to create database: %v", err)
294+
}
292295
}
293296

294297
// Run migrations
295298
if err := runMigrations(pool, databaseSchema); err != nil {
296-
pool.Close()
299+
if customPool == nil {
300+
pool.Close()
301+
}
297302
return nil, fmt.Errorf("failed to run migrations: %v", err)
298303
}
299304

300305
// Test the connection
301306
if err := pool.Ping(ctx); err != nil {
302-
pool.Close()
307+
if customPool == nil {
308+
pool.Close()
309+
}
303310
return nil, fmt.Errorf("failed to ping database: %v", err)
304311
}
305312

@@ -322,19 +329,28 @@ func (s *sysDB) launch(ctx context.Context) {
322329
}
323330

324331
func (s *sysDB) shutdown(ctx context.Context, timeout time.Duration) {
325-
s.logger.Debug("DBOS: Closing system database connection pool")
326-
327-
if s.pool != nil {
328-
// Will block until every acquired connection is released
329-
s.pool.Close()
330-
}
332+
s.logger.Debug("Closing system database connection pool")
331333

332334
if s.launched {
333335
// Wait for the notification loop to exit
334336
select {
335337
case <-s.notificationLoopDone:
336338
case <-time.After(timeout):
337-
s.logger.Warn("DBOS: Notification listener loop did not finish in time", "timeout", timeout)
339+
s.logger.Warn("Notification listener loop did not finish in time", "timeout", timeout)
340+
}
341+
}
342+
343+
if s.pool != nil {
344+
poolClose := make(chan struct{})
345+
go func() {
346+
// Will block until every acquired connection is released
347+
s.pool.Close()
348+
close(poolClose)
349+
}()
350+
select {
351+
case <-poolClose:
352+
case <-time.After(timeout):
353+
s.logger.Warn("System database connection pool did not close in time", "timeout", timeout)
338354
}
339355
}
340356

@@ -1556,6 +1572,7 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err
15561572

15571573
func (s *sysDB) notificationListenerLoop(ctx context.Context) {
15581574
defer func() {
1575+
s.logger.Debug("Notification listener loop exiting")
15591576
s.notificationLoopDone <- struct{}{}
15601577
}()
15611578

@@ -1613,8 +1630,9 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) {
16131630
n, err := poolConn.Conn().WaitForNotification(ctx)
16141631
if err != nil {
16151632
// Context cancellation -> graceful exit
1616-
if ctx.Err() != nil || strings.Contains(err.Error(), "pool closed") {
1617-
s.logger.Debug("Notification listener exiting (context canceled or pool closed)", "cause", context.Cause(ctx), "error", err)
1633+
if ctx.Err() != nil {
1634+
s.logger.Debug("Notification listener exiting (context canceled", "cause", context.Cause(ctx), "error", err)
1635+
poolConn.Release()
16181636
return
16191637
}
16201638
// If the underlying connection is closed, attempt to re-acquire a new one

0 commit comments

Comments
 (0)