|
1 | | -import { QueryResult } from "@powersync/common"; |
| 1 | +import { QueryResult } from '@powersync/common'; |
2 | 2 | import BetterSQLite3Database, { Database } from 'better-sqlite3'; |
3 | 3 | import * as Comlink from 'comlink'; |
4 | | -import { MessagePort, Worker, parentPort } from 'node:worker_threads'; |
| 4 | +import { MessagePort, parentPort, threadId } from 'node:worker_threads'; |
5 | 5 | import OS from 'node:os'; |
6 | 6 | import url from 'node:url'; |
7 | 7 |
|
8 | 8 | export type ProxiedQueryResult = Omit<QueryResult, 'rows'> & { |
9 | | - rows?: { |
10 | | - _array: any[]; |
11 | | - length: number; |
12 | | - }; |
| 9 | + rows?: { |
| 10 | + _array: any[]; |
| 11 | + length: number; |
13 | 12 | }; |
| 13 | +}; |
14 | 14 |
|
15 | 15 | export interface AsyncDatabase { |
16 | 16 | execute: (query: string, params: any[]) => Promise<ProxiedQueryResult>; |
17 | 17 | executeBatch: (query: string, params: any[][]) => Promise<ProxiedQueryResult>; |
18 | 18 | close: () => Promise<void>; |
| 19 | + // Collect table updates made since the last call to collectCommittedUpdates. |
| 20 | + // This happens on the worker because we otherwise get race conditions when wrapping |
| 21 | + // callbacks to invoke on the main thread (we need a guarantee that collectCommittedUpdates |
| 22 | + // contains entries immediately after calling COMMIT). |
| 23 | + collectCommittedUpdates: () => Promise<string[]>; |
19 | 24 | } |
20 | 25 |
|
21 | 26 | class BlockingAsyncDatabase implements AsyncDatabase { |
22 | | - private readonly db: Database |
| 27 | + private readonly db: Database; |
23 | 28 |
|
24 | | - constructor(db: Database) { |
25 | | - this.db = db; |
26 | | - } |
| 29 | + private readonly uncommittedUpdatedTables = new Set<string>(); |
| 30 | + private readonly committedUpdatedTables = new Set<string>(); |
27 | 31 |
|
28 | | - async close() { |
29 | | - this.db.close(); |
30 | | - } |
| 32 | + constructor(db: Database) { |
| 33 | + this.db = db; |
| 34 | + |
| 35 | + db.function('node_thread_id', () => threadId); |
| 36 | + } |
| 37 | + |
| 38 | + collectCommittedUpdates() { |
| 39 | + const resolved = Promise.resolve([...this.committedUpdatedTables]); |
| 40 | + this.committedUpdatedTables.clear(); |
| 41 | + return resolved; |
| 42 | + } |
| 43 | + |
| 44 | + installUpdateHooks() { |
| 45 | + this.db.updateHook((_op: string, _dbName: string, tableName: string, _rowid: bigint) => { |
| 46 | + this.uncommittedUpdatedTables.add(tableName); |
| 47 | + }); |
31 | 48 |
|
32 | | - async execute(query: string, params: any[]) { |
33 | | - const stmt = this.db.prepare(query); |
34 | | - if (stmt.reader) { |
35 | | - const rows = stmt.all(params); |
36 | | - return { |
37 | | - rowsAffected: 0, |
38 | | - rows: { |
39 | | - _array: rows, |
40 | | - length: rows.length, |
41 | | - }, |
42 | | - }; |
43 | | - } else { |
44 | | - const info = stmt.run(params); |
45 | | - return { |
46 | | - rowsAffected: info.changes, |
47 | | - insertId: Number(info.lastInsertRowid), |
48 | | - }; |
| 49 | + this.db.commitHook(() => { |
| 50 | + for (const tableName of this.uncommittedUpdatedTables) { |
| 51 | + this.committedUpdatedTables.add(tableName); |
| 52 | + } |
| 53 | + this.uncommittedUpdatedTables.clear(); |
| 54 | + return true; |
| 55 | + }); |
| 56 | + |
| 57 | + this.db.rollbackHook(() => { |
| 58 | + this.uncommittedUpdatedTables.clear(); |
| 59 | + }); |
| 60 | + } |
| 61 | + |
| 62 | + async close() { |
| 63 | + this.db.close(); |
| 64 | + } |
| 65 | + |
| 66 | + async execute(query: string, params: any[]) { |
| 67 | + const stmt = this.db.prepare(query); |
| 68 | + if (stmt.reader) { |
| 69 | + const rows = stmt.all(params); |
| 70 | + return { |
| 71 | + rowsAffected: 0, |
| 72 | + rows: { |
| 73 | + _array: rows, |
| 74 | + length: rows.length |
49 | 75 | } |
| 76 | + }; |
| 77 | + } else { |
| 78 | + const info = stmt.run(params); |
| 79 | + return { |
| 80 | + rowsAffected: info.changes, |
| 81 | + insertId: Number(info.lastInsertRowid) |
| 82 | + }; |
50 | 83 | } |
| 84 | + } |
51 | 85 |
|
52 | | - async executeBatch(query: string, params: any[][]) { |
53 | | - params = params ?? []; |
| 86 | + async executeBatch(query: string, params: any[][]) { |
| 87 | + params = params ?? []; |
54 | 88 |
|
55 | | - let rowsAffected = 0; |
56 | | - |
57 | | - const stmt = this.db.prepare(query); |
58 | | - for (const paramSet of params) { |
59 | | - const info = stmt.run(paramSet); |
60 | | - rowsAffected += info.changes; |
61 | | - } |
62 | | - |
63 | | - return { rowsAffected }; |
| 89 | + let rowsAffected = 0; |
| 90 | + |
| 91 | + const stmt = this.db.prepare(query); |
| 92 | + for (const paramSet of params) { |
| 93 | + const info = stmt.run(paramSet); |
| 94 | + rowsAffected += info.changes; |
64 | 95 | } |
| 96 | + |
| 97 | + return { rowsAffected }; |
| 98 | + } |
65 | 99 | } |
66 | 100 |
|
67 | 101 | export class BetterSqliteWorker { |
68 | | - open(path: string, isWriter: boolean): AsyncDatabase { |
69 | | - const baseDB = new BetterSQLite3Database(path); |
70 | | - baseDB.pragma('journal_mode = WAL'); |
71 | | - loadExtension(baseDB); |
72 | | - if (!isWriter) { |
73 | | - baseDB.pragma('query_only = true'); |
74 | | - } |
75 | | - |
76 | | - return Comlink.proxy(new BlockingAsyncDatabase(baseDB)); |
| 102 | + open(path: string, isWriter: boolean): AsyncDatabase { |
| 103 | + const baseDB = new BetterSQLite3Database(path); |
| 104 | + baseDB.pragma('journal_mode = WAL'); |
| 105 | + loadExtension(baseDB); |
| 106 | + if (!isWriter) { |
| 107 | + baseDB.pragma('query_only = true'); |
77 | 108 | } |
| 109 | + |
| 110 | + const asyncDb = new BlockingAsyncDatabase(baseDB); |
| 111 | + asyncDb.installUpdateHooks(); |
| 112 | + |
| 113 | + return Comlink.proxy(asyncDb); |
| 114 | + } |
78 | 115 | } |
79 | 116 |
|
80 | 117 | const platform = OS.platform(); |
81 | 118 | let extensionPath: string; |
82 | | -if (platform === "win32") { |
| 119 | +if (platform === 'win32') { |
83 | 120 | extensionPath = 'powersync.dll'; |
84 | | -} else if (platform === "linux") { |
| 121 | +} else if (platform === 'linux') { |
85 | 122 | extensionPath = 'libpowersync.so'; |
86 | | -} else if (platform === "darwin") { |
| 123 | +} else if (platform === 'darwin') { |
87 | 124 | extensionPath = 'libpowersync.dylib'; |
88 | 125 | } |
89 | 126 |
|
90 | 127 | const loadExtension = (db: Database) => { |
91 | 128 | const resolved = url.fileURLToPath(new URL(`../${extensionPath}`, import.meta.url)); |
92 | 129 | db.loadExtension(resolved, 'sqlite3_powersync_init'); |
93 | | -} |
| 130 | +}; |
94 | 131 |
|
95 | 132 | function toComlink(port: MessagePort): Comlink.Endpoint { |
96 | | - return { |
97 | | - postMessage: port.postMessage.bind(port), |
98 | | - start: port.start && port.start.bind(port), |
99 | | - addEventListener: port.addEventListener.bind(port), |
100 | | - removeEventListener: port.removeEventListener.bind(port), |
101 | | - }; |
| 133 | + return { |
| 134 | + postMessage: port.postMessage.bind(port), |
| 135 | + start: port.start && port.start.bind(port), |
| 136 | + addEventListener: port.addEventListener.bind(port), |
| 137 | + removeEventListener: port.removeEventListener.bind(port) |
| 138 | + }; |
102 | 139 | } |
103 | 140 |
|
104 | 141 | Comlink.expose(new BetterSqliteWorker(), toComlink(parentPort!)); |
0 commit comments