Skip to content

Commit fc9aa09

Browse files
authored
Fix: Accept custom pgxpool (#118)
Fixes #87
1 parent 2a091a2 commit fc9aa09

File tree

3 files changed

+108
-18
lines changed

3 files changed

+108
-18
lines changed

dbos/dbos.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/google/uuid"
19+
"github.com/jackc/pgx/v5/pgxpool"
1920
"github.com/robfig/cron/v3"
2021
)
2122

@@ -37,6 +38,7 @@ type Config struct {
3738
ApplicationVersion string // Application version (optional, overridden by DBOS__APPVERSION env var)
3839
ExecutorID string // Executor ID (optional, overridden by DBOS__VMID env var)
3940
Context context.Context // User Context
41+
SystemDBPool *pgxpool.Pool // Custom System Database Pool
4042
}
4143

4244
func processConfig(inputConfig *Config) (*Config, error) {
@@ -61,6 +63,7 @@ func processConfig(inputConfig *Config) (*Config, error) {
6163
ConductorAPIKey: inputConfig.ConductorAPIKey,
6264
ApplicationVersion: inputConfig.ApplicationVersion,
6365
ExecutorID: inputConfig.ExecutorID,
66+
SystemDBPool: inputConfig.SystemDBPool,
6467
}
6568

6669
// Load defaults
@@ -321,8 +324,14 @@ func NewDBOSContext(ctx context.Context, inputConfig Config) (DBOSContext, error
321324

322325
initExecutor.applicationID = os.Getenv("DBOS__APPID")
323326

327+
newSystemDatabaseInputs := newSystemDatabaseInput{
328+
databaseURL: config.DatabaseURL,
329+
customPool: config.SystemDBPool,
330+
logger: initExecutor.logger,
331+
}
332+
324333
// Create the system database
325-
systemDB, err := newSystemDatabase(initExecutor, config.DatabaseURL, initExecutor.logger)
334+
systemDB, err := newSystemDatabase(initExecutor, newSystemDatabaseInputs)
326335
if err != nil {
327336
return nil, newInitializationError(fmt.Sprintf("failed to create system database: %v", err))
328337
}

dbos/dbos_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package dbos
22

33
import (
4+
"bytes"
45
"context"
6+
"log/slog"
57
"testing"
68
"time"
79

10+
"github.com/jackc/pgx/v5/pgxpool"
811
"github.com/stretchr/testify/assert"
912
"github.com/stretchr/testify/require"
1013
)
@@ -58,6 +61,60 @@ func TestConfig(t *testing.T) {
5861
assert.Equal(t, expectedMsg, dbosErr.Message)
5962
})
6063

64+
t.Run("NewSystemDatabaseWithCustomPool", func(t *testing.T) {
65+
66+
// Logger
67+
var buf bytes.Buffer
68+
slogLogger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{
69+
Level: slog.LevelDebug,
70+
}))
71+
72+
slogLogger = slogLogger.With("service", "dbos-test", "environment", "test")
73+
74+
// Custom Pool
75+
poolConfig, err := pgxpool.ParseConfig(databaseURL)
76+
require.NoError(t, err)
77+
78+
poolConfig.MaxConns = 10
79+
poolConfig.MinConns = 5
80+
poolConfig.MaxConnLifetime = 2 * time.Hour
81+
poolConfig.MaxConnIdleTime = time.Minute * 2
82+
83+
poolConfig.ConnConfig.ConnectTimeout = 10 * time.Second
84+
85+
pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
86+
require.NoError(t, err)
87+
88+
config := Config{
89+
DatabaseURL: databaseURL,
90+
AppName: "test-custom-pool",
91+
Logger: slogLogger,
92+
SystemDBPool: pool,
93+
}
94+
95+
customdbosContext, err := NewDBOSContext(context.Background(), config)
96+
require.NoError(t, err)
97+
require.NotNil(t, customdbosContext)
98+
99+
dbosCtx, ok := customdbosContext.(*dbosContext)
100+
defer dbosCtx.Shutdown(10 * time.Second)
101+
require.True(t, ok)
102+
103+
sysDB, ok := dbosCtx.systemDB.(*sysDB)
104+
require.True(t, ok)
105+
assert.Same(t, pool, sysDB.pool, "The pool in dbosContext should be the same as the custom pool provided")
106+
107+
stats := sysDB.pool.Stat()
108+
assert.Equal(t, int32(10), stats.MaxConns(), "MaxConns should match custom pool config")
109+
110+
sysdbConfig := sysDB.pool.Config()
111+
assert.Equal(t, int32(10), sysdbConfig.MaxConns)
112+
assert.Equal(t, int32(5), sysdbConfig.MinConns)
113+
assert.Equal(t, 2*time.Hour, sysdbConfig.MaxConnLifetime)
114+
assert.Equal(t, 2*time.Minute, sysdbConfig.MaxConnIdleTime)
115+
assert.Equal(t, 10*time.Second, sysdbConfig.ConnConfig.ConnectTimeout)
116+
})
117+
61118
t.Run("FailsWithoutDatabaseURL", func(t *testing.T) {
62119
config := Config{
63120
AppName: "test-app",

dbos/system_database.go

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,19 @@ func runMigrations(databaseURL string) error {
229229
return nil
230230
}
231231

232+
type newSystemDatabaseInput struct {
233+
databaseURL string
234+
customPool *pgxpool.Pool
235+
logger *slog.Logger
236+
}
237+
232238
// New creates a new SystemDatabase instance and runs migrations
233-
func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Logger) (systemDatabase, error) {
239+
func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (systemDatabase, error) {
240+
// Dereference fields from inputs
241+
databaseURL := inputs.databaseURL
242+
customPool := inputs.customPool
243+
logger := inputs.logger
244+
234245
// Create the database if it doesn't exist
235246
if err := createDatabaseIfNotExists(ctx, databaseURL, logger); err != nil {
236247
return nil, fmt.Errorf("failed to create database: %v", err)
@@ -241,24 +252,37 @@ func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Log
241252
return nil, fmt.Errorf("failed to run migrations: %v", err)
242253
}
243254

244-
// Parse the connection string to get a config
245-
config, err := pgxpool.ParseConfig(databaseURL)
246-
if err != nil {
247-
return nil, fmt.Errorf("failed to parse database URL: %v", err)
248-
}
249-
// Set pool configuration
250-
config.MaxConns = 20
251-
config.MinConns = 0
252-
config.MaxConnLifetime = time.Hour
253-
config.MaxConnIdleTime = time.Minute * 5
255+
// pool
256+
var pool *pgxpool.Pool
254257

255-
// Add acquire timeout to prevent indefinite blocking
256-
config.ConnConfig.ConnectTimeout = 10 * time.Second
258+
if customPool != nil {
259+
260+
pool = customPool
261+
262+
} else {
263+
264+
// Parse the connection string to get a config
265+
config, err := pgxpool.ParseConfig(databaseURL)
266+
if err != nil {
267+
return nil, fmt.Errorf("failed to parse database URL: %v", err)
268+
}
269+
270+
// Set pool configuration
271+
config.MaxConns = 20
272+
config.MinConns = 0
273+
config.MaxConnLifetime = time.Hour
274+
config.MaxConnIdleTime = time.Minute * 5
275+
276+
// Add acquire timeout to prevent indefinite blocking
277+
config.ConnConfig.ConnectTimeout = 10 * time.Second
278+
279+
// Create pool with configuration
280+
newPool, err := pgxpool.NewWithConfig(ctx, config)
281+
if err != nil {
282+
return nil, fmt.Errorf("failed to create connection pool: %v", err)
283+
}
284+
pool = newPool
257285

258-
// Create pool with configuration
259-
pool, err := pgxpool.NewWithConfig(ctx, config)
260-
if err != nil {
261-
return nil, fmt.Errorf("failed to create connection pool: %v", err)
262286
}
263287

264288
// Test the connection

0 commit comments

Comments
 (0)