Skip to content

Commit 54a4873

Browse files
authored
Allow extensions to be loaded in pglite-socket (#871)
* allow extensions to be loaded in pglite-socket * changeset * update ubuntu packages in CI * update ubuntu packages * not update ubuntu packages * docs * cleanup * --extension parameter test; allow allow loading extensions from any installed npm package * style
1 parent 9615075 commit 54a4873

File tree

7 files changed

+461
-2
lines changed

7 files changed

+461
-2
lines changed

.changeset/curly-taxis-try.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@electric-sql/pglite-socket': patch
3+
---
4+
5+
allow extensions to be loaded via '-e/--extensions <list>' cmd line parameter'

.github/workflows/build_and_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ jobs:
9595
working-directory: ./packages/pglite
9696
needs: [build-all]
9797
steps:
98+
9899
- uses: actions/checkout@v4
99100
- uses: pnpm/action-setup@v4
100101
- uses: actions/setup-node@v4

packages/pglite-socket/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ pglite-server --help
160160
- `-h, --host=HOST` - Host to bind to (default: 127.0.0.1)
161161
- `-u, --path=UNIX` - Unix socket to bind to (takes precedence over host:port)
162162
- `-v, --debug=LEVEL` - Debug level 0-5 (default: 0)
163+
- `-e, --extensions=LIST` - Comma-separated list of extensions to load (e.g., vector,pgcrypto)
163164
- `-r, --run=COMMAND` - Command to run after server starts
164165
- `--include-database-url` - Include DATABASE_URL in subprocess environment
165166
- `--shutdown-timeout=MS` - Timeout for graceful subprocess shutdown in ms (default: 5000)

packages/pglite-socket/src/scripts/server.ts

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env node
22

33
import { PGlite, DebugLevel } from '@electric-sql/pglite'
4+
import type { Extension, Extensions } from '@electric-sql/pglite'
45
import { PGLiteSocketServer } from '../index'
56
import { parseArgs } from 'node:util'
67
import { spawn, ChildProcess } from 'node:child_process'
@@ -38,6 +39,12 @@ const args = parseArgs({
3839
default: '0',
3940
help: 'Debug level (0-5)',
4041
},
42+
extensions: {
43+
type: 'string',
44+
short: 'e',
45+
default: undefined,
46+
help: 'Comma-separated list of extensions to load (e.g., vector,pgcrypto)',
47+
},
4148
run: {
4249
type: 'string',
4350
short: 'r',
@@ -72,6 +79,9 @@ Options:
7279
-h, --host=HOST Host to bind to (default: 127.0.0.1)
7380
-u, --path=UNIX Unix socket to bind to (default: undefined). Takes precedence over host:port
7481
-v, --debug=LEVEL Debug level 0-5 (default: 0)
82+
-e, --extensions=LIST Comma-separated list of extensions to load
83+
Formats: vector, pgcrypto (built-in/contrib)
84+
@org/package/path:exportedName (npm package)
7585
-r, --run=COMMAND Command to run after server starts
7686
--include-database-url Include DATABASE_URL in subprocess environment
7787
--shutdown-timeout=MS Timeout for graceful subprocess shutdown in ms (default: 5000)
@@ -83,6 +93,7 @@ interface ServerConfig {
8393
host: string
8494
path?: string
8595
debugLevel: DebugLevel
96+
extensionNames?: string[]
8697
runCommand?: string
8798
includeDatabaseUrl: boolean
8899
shutdownTimeout: number
@@ -99,12 +110,16 @@ class PGLiteServerRunner {
99110
}
100111

101112
static parseConfig(): ServerConfig {
113+
const extensionsArg = args.values.extensions as string | undefined
102114
return {
103115
dbPath: args.values.db as string,
104116
port: parseInt(args.values.port as string, 10),
105117
host: args.values.host as string,
106118
path: args.values.path as string,
107119
debugLevel: parseInt(args.values.debug as string, 10) as DebugLevel,
120+
extensionNames: extensionsArg
121+
? extensionsArg.split(',').map((e) => e.trim())
122+
: undefined,
108123
runCommand: args.values.run as string,
109124
includeDatabaseUrl: args.values['include-database-url'] as boolean,
110125
shutdownTimeout: parseInt(args.values['shutdown-timeout'] as string, 10),
@@ -126,11 +141,86 @@ class PGLiteServerRunner {
126141
}
127142
}
128143

144+
private async importExtensions(): Promise<Extensions | undefined> {
145+
if (!this.config.extensionNames?.length) {
146+
return undefined
147+
}
148+
149+
const extensions: Extensions = {}
150+
151+
// Built-in extensions that are not in contrib
152+
const builtInExtensions = [
153+
'vector',
154+
'live',
155+
'pg_hashids',
156+
'pg_ivm',
157+
'pg_uuidv7',
158+
'pgtap',
159+
]
160+
161+
for (const name of this.config.extensionNames) {
162+
let ext: Extension | null = null
163+
164+
try {
165+
// Check if this is a custom package path (contains ':')
166+
// Format: @org/package/path:exportedName or package/path:exportedName
167+
if (name.includes(':')) {
168+
const [packagePath, exportName] = name.split(':')
169+
if (!packagePath || !exportName) {
170+
throw new Error(
171+
`Invalid extension format '${name}'. Expected: package/path:exportedName`,
172+
)
173+
}
174+
const mod = await import(packagePath)
175+
ext = mod[exportName] as Extension
176+
if (ext) {
177+
extensions[exportName] = ext
178+
console.log(
179+
`Imported extension '${exportName}' from '${packagePath}'`,
180+
)
181+
}
182+
} else if (builtInExtensions.includes(name)) {
183+
// Built-in extension (e.g., @electric-sql/pglite/vector)
184+
const mod = await import(`@electric-sql/pglite/${name}`)
185+
ext = mod[name] as Extension
186+
if (ext) {
187+
extensions[name] = ext
188+
console.log(`Imported extension: ${name}`)
189+
}
190+
} else {
191+
// Try contrib first (e.g., @electric-sql/pglite/contrib/pgcrypto)
192+
try {
193+
const mod = await import(`@electric-sql/pglite/contrib/${name}`)
194+
ext = mod[name] as Extension
195+
} catch {
196+
// Fall back to external package (e.g., @electric-sql/pglite-<extension>)
197+
const mod = await import(`@electric-sql/pglite-${name}`)
198+
ext = mod[name] as Extension
199+
}
200+
if (ext) {
201+
extensions[name] = ext
202+
console.log(`Imported extension: ${name}`)
203+
}
204+
}
205+
} catch (error) {
206+
console.error(`Failed to import extension '${name}':`, error)
207+
throw new Error(`Failed to import extension '${name}'`)
208+
}
209+
}
210+
211+
return Object.keys(extensions).length > 0 ? extensions : undefined
212+
}
213+
129214
private async initializeDatabase(): Promise<void> {
130215
console.log(`Initializing PGLite with database: ${this.config.dbPath}`)
131216
console.log(`Debug level: ${this.config.debugLevel}`)
132217

133-
this.db = new PGlite(this.config.dbPath, { debug: this.config.debugLevel })
218+
const extensions = await this.importExtensions()
219+
220+
this.db = new PGlite(this.config.dbPath, {
221+
debug: this.config.debugLevel,
222+
extensions,
223+
})
134224
await this.db.waitReady
135225
console.log('PGlite database initialized')
136226
}

packages/pglite-socket/tests/query-with-node-pg.test.ts

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ import {
1010
import { Client } from 'pg'
1111
import { PGlite } from '@electric-sql/pglite'
1212
import { PGLiteSocketServer } from '../src'
13+
import { spawn, ChildProcess } from 'node:child_process'
14+
import { fileURLToPath } from 'node:url'
15+
import { dirname, join } from 'node:path'
16+
import fs from 'fs'
17+
18+
const __filename = fileURLToPath(import.meta.url)
19+
const __dirname = dirname(__filename)
1320

1421
/**
1522
* Debug configuration for testing
@@ -533,4 +540,178 @@ describe(`PGLite Socket Server`, () => {
533540
expect(receivedPayload).toBe('Hello from PGlite!')
534541
})
535542
})
543+
544+
describe('with extensions via CLI', () => {
545+
const UNIX_SOCKET_DIR_PATH = `/tmp/${Date.now().toString()}`
546+
fs.mkdirSync(UNIX_SOCKET_DIR_PATH)
547+
const UNIX_SOCKET_PATH = `${UNIX_SOCKET_DIR_PATH}/.s.PGSQL.5432`
548+
let serverProcess: ChildProcess | null = null
549+
let client: typeof Client.prototype
550+
551+
beforeAll(async () => {
552+
// Start the server with extensions via CLI using tsx for dev or node for dist
553+
const serverScript = join(__dirname, '../src/scripts/server.ts')
554+
serverProcess = spawn(
555+
'npx',
556+
[
557+
'tsx',
558+
serverScript,
559+
'--path',
560+
UNIX_SOCKET_PATH,
561+
'--extensions',
562+
'vector,pg_uuidv7,@electric-sql/pglite/pg_hashids:pg_hashids',
563+
],
564+
{
565+
stdio: ['ignore', 'pipe', 'pipe'],
566+
},
567+
)
568+
569+
// Wait for server to be ready by checking for "listening" message
570+
await new Promise<void>((resolve, reject) => {
571+
const timeout = setTimeout(() => {
572+
reject(new Error('Server startup timeout'))
573+
}, 30000)
574+
575+
const onData = (data: Buffer) => {
576+
const output = data.toString()
577+
if (output.includes('listening')) {
578+
clearTimeout(timeout)
579+
resolve()
580+
}
581+
}
582+
583+
serverProcess!.stdout?.on('data', onData)
584+
serverProcess!.stderr?.on('data', (data) => {
585+
console.error('Server stderr:', data.toString())
586+
})
587+
588+
serverProcess!.on('error', (err) => {
589+
clearTimeout(timeout)
590+
reject(err)
591+
})
592+
593+
serverProcess!.on('exit', (code) => {
594+
if (code !== 0 && code !== null) {
595+
clearTimeout(timeout)
596+
reject(new Error(`Server exited with code ${code}`))
597+
}
598+
})
599+
})
600+
601+
console.log('Server with extensions started')
602+
603+
client = new Client({
604+
host: UNIX_SOCKET_DIR_PATH,
605+
database: 'postgres',
606+
user: 'postgres',
607+
password: 'postgres',
608+
connectionTimeoutMillis: 10000,
609+
})
610+
await client.connect()
611+
})
612+
613+
afterAll(async () => {
614+
if (client) {
615+
await client.end().catch(() => {})
616+
}
617+
618+
if (serverProcess) {
619+
serverProcess.kill('SIGTERM')
620+
await new Promise<void>((resolve) => {
621+
serverProcess!.on('exit', () => resolve())
622+
setTimeout(resolve, 2000) // Force resolve after 2s
623+
})
624+
}
625+
})
626+
627+
it('should load and use vector extension', async () => {
628+
// Create the extension
629+
await client.query('CREATE EXTENSION IF NOT EXISTS vector')
630+
631+
// Verify extension is loaded
632+
const extCheck = await client.query(`
633+
SELECT extname FROM pg_extension WHERE extname = 'vector'
634+
`)
635+
expect(extCheck.rows).toHaveLength(1)
636+
expect(extCheck.rows[0].extname).toBe('vector')
637+
638+
// Create a table with vector column
639+
await client.query(`
640+
CREATE TABLE test_vectors (
641+
id SERIAL PRIMARY KEY,
642+
name TEXT,
643+
vec vector(3)
644+
)
645+
`)
646+
647+
// Insert test data
648+
await client.query(`
649+
INSERT INTO test_vectors (name, vec) VALUES
650+
('test1', '[1,2,3]'),
651+
('test2', '[4,5,6]'),
652+
('test3', '[7,8,9]')
653+
`)
654+
655+
// Query with vector distance
656+
const result = await client.query(`
657+
SELECT name, vec, vec <-> '[3,1,2]' AS distance
658+
FROM test_vectors
659+
ORDER BY distance
660+
`)
661+
662+
expect(result.rows).toHaveLength(3)
663+
expect(result.rows[0].name).toBe('test1')
664+
expect(result.rows[0].vec).toBe('[1,2,3]')
665+
expect(parseFloat(result.rows[0].distance)).toBeCloseTo(2.449, 2)
666+
})
667+
668+
it('should load and use pg_uuidv7 extension', async () => {
669+
// Create the extension
670+
await client.query('CREATE EXTENSION IF NOT EXISTS pg_uuidv7')
671+
672+
// Verify extension is loaded
673+
const extCheck = await client.query(`
674+
SELECT extname FROM pg_extension WHERE extname = 'pg_uuidv7'
675+
`)
676+
expect(extCheck.rows).toHaveLength(1)
677+
expect(extCheck.rows[0].extname).toBe('pg_uuidv7')
678+
679+
// Generate a UUIDv7
680+
const result = await client.query('SELECT uuid_generate_v7() as uuid')
681+
expect(result.rows[0].uuid).toHaveLength(36)
682+
683+
// Test uuid_v7_to_timestamptz function
684+
const tsResult = await client.query(`
685+
SELECT uuid_v7_to_timestamptz('018570bb-4a7d-7c7e-8df4-6d47afd8c8fc') as ts
686+
`)
687+
const timestamp = new Date(tsResult.rows[0].ts)
688+
expect(timestamp.toISOString()).toBe('2023-01-02T04:26:40.637Z')
689+
})
690+
691+
it('should load and use pg_hashids extension from npm package path', async () => {
692+
// Create the extension
693+
await client.query('CREATE EXTENSION IF NOT EXISTS pg_hashids')
694+
695+
// Verify extension is loaded
696+
const extCheck = await client.query(`
697+
SELECT extname FROM pg_extension WHERE extname = 'pg_hashids'
698+
`)
699+
expect(extCheck.rows).toHaveLength(1)
700+
expect(extCheck.rows[0].extname).toBe('pg_hashids')
701+
702+
// Test id_encode function
703+
const result = await client.query(`
704+
SELECT id_encode(1234567, 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as hash
705+
`)
706+
expect(result.rows[0].hash).toBeTruthy()
707+
expect(typeof result.rows[0].hash).toBe('string')
708+
709+
// Test id_decode function (round-trip)
710+
const hash = result.rows[0].hash
711+
const decodeResult = await client.query(`
712+
SELECT id_decode('${hash}', 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as id
713+
`)
714+
expect(decodeResult.rows[0].id[0]).toBe('1234567')
715+
})
716+
})
536717
})

0 commit comments

Comments
 (0)