Skip to content

Commit 633a3f8

Browse files
committed
use custom pool for everything in system db
1 parent 5eb5ab1 commit 633a3f8

File tree

1 file changed

+113
-99
lines changed

1 file changed

+113
-99
lines changed

dbos/system_database.go

Lines changed: 113 additions & 99 deletions
Original file line number | Diff line number | Diff line change
@@ -66,35 +66,40 @@ type systemDatabase interface {
6666
}
6767

6868
type sysDB struct {
69-
pool *pgxpool.Pool
70-
notificationListenerConnection *pgconn.PgConn
71-
notificationLoopDone chan struct{}
72-
notificationsMap *sync.Map
73-
logger *slog.Logger
74-
schema string
75-
launched bool
69+
pool *pgxpool.Pool
70+
notificationLoopDone chan struct{}
71+
notificationsMap *sync.Map
72+
logger *slog.Logger
73+
schema string
74+
launched bool
7675
}
7776

7877
/*******************************/
7978
/******* INITIALIZATION ********/
8079
/*******************************/
8180

8281
// createDatabaseIfNotExists creates the database if it doesn't exist
83-
func createDatabaseIfNotExists(ctx context.Context, databaseURL string, logger *slog.Logger) error {
84-
// Connect to the postgres database
85-
parsedURL, err := pgx.ParseConfig(databaseURL)
86-
if err != nil {
87-
return newInitializationError(fmt.Sprintf("failed to parse database URL: %v", err))
82+
func createDatabaseIfNotExists(ctx context.Context, pool *pgxpool.Pool, logger *slog.Logger) error {
83+
// Try to acquire a connection from the pool first
84+
poolConn, err := pool.Acquire(ctx)
85+
if err == nil {
86+
// Pool connection works, database likely exists
87+
poolConn.Release()
88+
return nil
8889
}
90+
// Fall through to database creation
8991

90-
dbName := parsedURL.Database
92+
// Get the database name from the pool config
93+
poolConfig := pool.Config()
94+
dbName := poolConfig.ConnConfig.Database
9195
if dbName == "" {
92-
return newInitializationError("database name not found in URL")
96+
return newInitializationError("database name not found in pool configuration")
9397
}
9498

95-
serverURL := parsedURL.Copy()
96-
serverURL.Database = "postgres"
97-
conn, err := pgx.ConnectConfig(ctx, serverURL)
99+
// Create a connection to the postgres database to create the target database
100+
serverConfig := poolConfig.ConnConfig.Copy()
101+
serverConfig.Database = "postgres"
102+
conn, err := pgx.ConnectConfig(ctx, serverConfig)
98103
if err != nil {
99104
return newInitializationError(fmt.Sprintf("failed to connect to PostgreSQL server: %v", err))
100105
}
@@ -146,7 +151,7 @@ const (
146151
_DB_RETRY_INTERVAL = 1 * time.Second
147152
)
148153

149-
func runMigrations(databaseURL string, schema string) error {
154+
func runMigrations(pool *pgxpool.Pool, schema string) error {
150155
// Process the migration SQL with fmt.Sprintf (22 schema placeholders)
151156
sanitizedSchema := pgx.Identifier{schema}.Sanitize()
152157
migrationSQL := fmt.Sprintf(migration1SQL,
@@ -161,13 +166,6 @@ func runMigrations(databaseURL string, schema string) error {
161166
{version: 1, sql: migrationSQL},
162167
}
163168

164-
// Connect to the database
165-
pool, err := pgxpool.New(context.Background(), databaseURL)
166-
if err != nil {
167-
return fmt.Errorf("failed to create connection pool: %v", err)
168-
}
169-
defer pool.Close()
170-
171169
// Begin transaction for atomic migration execution
172170
ctx := context.Background()
173171
tx, err := pool.Begin(ctx)
@@ -258,19 +256,10 @@ func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (syst
258256
return nil, fmt.Errorf("database schema cannot be empty")
259257
}
260258

261-
// Create the database if it doesn't exist
262-
if err := createDatabaseIfNotExists(ctx, databaseURL, logger); err != nil {
263-
return nil, fmt.Errorf("failed to create database: %v", err)
264-
}
265-
266-
// Run migrations first
267-
if err := runMigrations(databaseURL, databaseSchema); err != nil {
268-
return nil, fmt.Errorf("failed to run migrations: %v", err)
269-
}
270-
271-
// pool
259+
// Configure a connection pool
272260
var pool *pgxpool.Pool
273261
if customPool != nil {
262+
logger.Info("Using custom database connection pool")
274263
pool = customPool
275264
} else {
276265
// Parse the connection string to get a config
@@ -296,6 +285,18 @@ func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (syst
296285
pool = newPool
297286
}
298287

288+
// Create the database if it doesn't exist
289+
if err := createDatabaseIfNotExists(ctx, pool, logger); err != nil {
290+
pool.Close()
291+
return nil, fmt.Errorf("failed to create database: %v", err)
292+
}
293+
294+
// Run migrations
295+
if err := runMigrations(pool, databaseSchema); err != nil {
296+
pool.Close()
297+
return nil, fmt.Errorf("failed to run migrations: %v", err)
298+
}
299+
299300
// Test the connection
300301
if err := pool.Ping(ctx); err != nil {
301302
pool.Close()
@@ -305,32 +306,12 @@ func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (syst
305306
// Create a map of notification payloads to channels
306307
notificationsMap := &sync.Map{}
307308

308-
// Create a connection to listen on notifications
309-
notifierConnConfig, err := pgconn.ParseConfig(databaseURL)
310-
if err != nil {
311-
return nil, fmt.Errorf("failed to parse database URL: %v", err)
312-
}
313-
notifierConnConfig.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
314-
if n.Channel == _DBOS_NOTIFICATIONS_CHANNEL || n.Channel == _DBOS_WORKFLOW_EVENTS_CHANNEL {
315-
// Check if an entry exists in the map, indexed by the payload
316-
// If yes, broadcast on the condition variable so listeners can wake up
317-
if cond, exists := notificationsMap.Load(n.Payload); exists {
318-
cond.(*sync.Cond).Broadcast()
319-
}
320-
}
321-
}
322-
notificationListenerConnection, err := pgconn.ConnectConfig(ctx, notifierConnConfig)
323-
if err != nil {
324-
return nil, fmt.Errorf("failed to connect notification listener to database: %v", err)
325-
}
326-
327309
return &sysDB{
328-
pool: pool,
329-
notificationListenerConnection: notificationListenerConnection,
330-
notificationsMap: notificationsMap,
331-
notificationLoopDone: make(chan struct{}),
332-
logger: logger.With("service", "system_database"),
333-
schema: databaseSchema,
310+
pool: pool,
311+
notificationsMap: notificationsMap,
312+
notificationLoopDone: make(chan struct{}),
313+
logger: logger.With("service", "system_database"),
314+
schema: databaseSchema,
334315
}, nil
335316
}
336317

@@ -342,18 +323,12 @@ func (s *sysDB) launch(ctx context.Context) {
342323

343324
func (s *sysDB) shutdown(ctx context.Context, timeout time.Duration) {
344325
s.logger.Debug("DBOS: Closing system database connection pool")
326+
345327
if s.pool != nil {
328+
// Will block until every acquired connection is released
346329
s.pool.Close()
347330
}
348331

349-
// Context wasn't cancelled, let's manually close
350-
if !errors.Is(ctx.Err(), context.Canceled) {
351-
err := s.notificationListenerConnection.Close(ctx)
352-
if err != nil {
353-
s.logger.Error("Failed to close notification listener connection", "error", err)
354-
}
355-
}
356-
357332
if s.launched {
358333
// Wait for the notification loop to exit
359334
s.logger.Debug("DBOS: Waiting for notification listener loop to finish")
@@ -1585,52 +1560,91 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) {
15851560
s.notificationLoopDone <- struct{}{}
15861561
}()
15871562

1588-
s.logger.Debug("DBOS: Starting notification listener loop")
1589-
mrr := s.notificationListenerConnection.Exec(ctx, fmt.Sprintf("LISTEN %s; LISTEN %s", _DBOS_NOTIFICATIONS_CHANNEL, _DBOS_WORKFLOW_EVENTS_CHANNEL))
1590-
results, err := mrr.ReadAll()
1591-
if err != nil {
1592-
s.logger.Error("Failed to listen on notification channels", "error", err)
1593-
return
1563+
acquire := func(ctx context.Context) (*pgxpool.Conn, error) {
1564+
// Acquire a connection from the pool and set up LISTEN on the notifications channels
1565+
pc, err := s.pool.Acquire(ctx)
1566+
if err != nil {
1567+
return nil, err
1568+
}
1569+
tx, err := pc.Begin(ctx)
1570+
if err != nil {
1571+
pc.Release()
1572+
return nil, err
1573+
}
1574+
if _, err = tx.Exec(ctx, fmt.Sprintf("LISTEN %s", _DBOS_NOTIFICATIONS_CHANNEL)); err != nil {
1575+
_ = tx.Rollback(ctx)
1576+
pc.Release()
1577+
return nil, err
1578+
}
1579+
if _, err = tx.Exec(ctx, fmt.Sprintf("LISTEN %s", _DBOS_WORKFLOW_EVENTS_CHANNEL)); err != nil {
1580+
_ = tx.Rollback(ctx)
1581+
pc.Release()
1582+
return nil, err
1583+
}
1584+
if err = tx.Commit(ctx); err != nil {
1585+
_ = tx.Rollback(ctx)
1586+
pc.Release()
1587+
return nil, err
1588+
}
1589+
return pc, nil
15941590
}
1595-
err = mrr.Close()
1591+
1592+
s.logger.Debug("DBOS: Starting notification listener loop")
1593+
1594+
poolConn, err := acquire(ctx)
15961595
if err != nil {
1597-
s.logger.Error("Failed to close connection after setting notification listeners", "error", err)
1596+
s.logger.Error("Failed to acquire listener connection", "error", err)
15981597
return
15991598
}
1600-
1601-
for _, result := range results {
1602-
if result.Err != nil {
1603-
s.logger.Error("Error listening on notification channels", "error", result.Err)
1604-
return
1605-
}
1606-
}
1599+
defer poolConn.Release()
16071600

16081601
retryAttempt := 0
16091602
for {
16101603
// Block until a notification is received. WaitForNotification returns the notification, which is dispatched to waiters below.
16111604
// WaitForNotification handles context cancellation: https://github.com/jackc/pgx/blob/15bca4a4e14e0049777c1245dba4c16300fe4fd0/pgconn/pgconn.go#L1050
1612-
err := s.notificationListenerConnection.WaitForNotification(ctx)
1605+
n, err := poolConn.Conn().WaitForNotification(ctx)
16131606
if err != nil {
1614-
// Context cancellation
1615-
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
1616-
s.logger.Debug("Notification listener loop exiting due to context cancellation", "cause", context.Cause(ctx), "error", err)
1607+
// Context cancellation -> graceful exit
1608+
if ctx.Err() != nil || strings.Contains(err.Error(), "pool closed") {
1609+
s.logger.Debug("Notification listener exiting (context canceled or pool closed)", "cause", context.Cause(ctx), "error", err)
16171610
return
16181611
}
1619-
1620-
// Connection closed (during shutdown) - exit gracefully
1621-
if s.notificationListenerConnection.IsClosed() {
1622-
s.logger.Info("Notification listener loop exiting due to connection closure")
1623-
return
1612+
// If the underlying connection is closed, attempt to re-acquire a new one
1613+
if poolConn.Conn().IsClosed() {
1614+
s.logger.Error("Notification listener connection closed. re-acquiring")
1615+
poolConn.Release()
1616+
for {
1617+
if ctx.Err() != nil || strings.Contains(err.Error(), "pool closed") {
1618+
s.logger.Debug("Notification listener exiting (context canceled or pool closed)", "cause", context.Cause(ctx), "error", err)
1619+
return
1620+
}
1621+
poolConn, err = acquire(ctx)
1622+
if err == nil {
1623+
retryAttempt = 0
1624+
break
1625+
}
1626+
s.logger.Error("failed to re-acquire connection for notification listener", "error", err)
1627+
time.Sleep(backoffWithJitter(retryAttempt))
1628+
retryAttempt++
1629+
}
1630+
continue
16241631
}
1625-
1626-
// Other errors - log and retry.
1632+
// Other transient errors. Backoff and continue on same conn
16271633
s.logger.Error("Error waiting for notification", "error", err)
16281634
time.Sleep(backoffWithJitter(retryAttempt))
1629-
retryAttempt += 1
1635+
retryAttempt++
16301636
continue
1631-
} else {
1632-
if retryAttempt > 0 {
1633-
retryAttempt -= 1
1637+
}
1638+
1639+
// Success: reduce backoff pressure
1640+
if retryAttempt > 0 {
1641+
retryAttempt--
1642+
}
1643+
1644+
// Handle notifications
1645+
if n.Channel == _DBOS_NOTIFICATIONS_CHANNEL || n.Channel == _DBOS_WORKFLOW_EVENTS_CHANNEL {
1646+
if cond, ok := s.notificationsMap.Load(n.Payload); ok {
1647+
cond.(*sync.Cond).Broadcast()
16341648
}
16351649
}
16361650
}

0 commit comments

Comments
 (0)