Skip to content

Commit 42a6294

Browse files
committed
Added watch function to drizzle db API.
1 parent 7e23d65 commit 42a6294

File tree

3 files changed

+348
-3
lines changed

3 files changed

+348
-3
lines changed

.changeset/calm-baboons-worry.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@powersync/drizzle-driver': minor
3+
---
4+
5+
Added `watch()` function to support watched queries. This function invokes `execute()` on the Drizzle query which improves support for complex queries such as those which are relational.

packages/drizzle-driver/src/sqlite/db.ts

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
import { AbstractPowerSyncDatabase, QueryResult } from '@powersync/common';
1+
import {
2+
AbstractPowerSyncDatabase,
3+
QueryResult,
4+
runOnSchemaChange,
5+
SQLWatchOptions,
6+
WatchHandler
7+
} from '@powersync/common';
8+
import { Query } from 'drizzle-orm';
29
import { DefaultLogger } from 'drizzle-orm/logger';
310
import {
411
createTableRelationsHelpers,
@@ -13,6 +20,8 @@ import { SQLiteAsyncDialect } from 'drizzle-orm/sqlite-core/dialect';
1320
import type { DrizzleConfig } from 'drizzle-orm/utils';
1421
import { PowerSyncSQLiteSession, PowerSyncSQLiteTransactionConfig } from './sqlite-session';
1522

23+
type WatchQuery = { toSQL(): Query; execute(): Promise<any> };
24+
1625
export interface PowerSyncSQLiteDatabase<TSchema extends Record<string, unknown> = Record<string, never>>
1726
extends BaseSQLiteDatabase<'async', QueryResult, TSchema> {
1827
transaction<T>(
@@ -21,13 +30,15 @@ export interface PowerSyncSQLiteDatabase<TSchema extends Record<string, unknown>
2130
) => Promise<T>,
2231
config?: PowerSyncSQLiteTransactionConfig
2332
): Promise<T>;
33+
34+
watch(query: WatchQuery, handler?: WatchHandler, options?: SQLWatchOptions): void;
2435
}
2536

2637
export function wrapPowerSyncWithDrizzle<TSchema extends Record<string, unknown> = Record<string, never>>(
2738
db: AbstractPowerSyncDatabase,
2839
config: DrizzleConfig<TSchema> = {}
2940
): PowerSyncSQLiteDatabase<TSchema> {
30-
const dialect = new SQLiteAsyncDialect({casing: config.casing});
41+
const dialect = new SQLiteAsyncDialect({ casing: config.casing });
3142
let logger;
3243
if (config.logger === true) {
3344
logger = new DefaultLogger();
@@ -48,5 +59,51 @@ export function wrapPowerSyncWithDrizzle<TSchema extends Record<string, unknown>
4859
const session = new PowerSyncSQLiteSession(db, dialect, schema, {
4960
logger
5061
});
51-
return new BaseSQLiteDatabase('async', dialect, session, schema) as PowerSyncSQLiteDatabase<TSchema>;
62+
63+
const watch = (query: WatchQuery, handler?: WatchHandler, options?: SQLWatchOptions): void => {
64+
const { onResult, onError = (e: Error) => {} } = handler ?? {};
65+
if (!onResult) {
66+
throw new Error('onResult is required');
67+
}
68+
69+
const watchQuery = async (abortSignal: AbortSignal) => {
70+
try {
71+
const toSql = query.toSQL();
72+
const resolvedTables = await db.resolveTables(toSql.sql, toSql.params, options);
73+
74+
// Fetch initial data
75+
const result = await query.execute();
76+
onResult(result);
77+
78+
db.onChangeWithCallback(
79+
{
80+
onChange: async () => {
81+
try {
82+
const result = await query.execute();
83+
onResult(result);
84+
} catch (error: any) {
85+
onError(error);
86+
}
87+
},
88+
onError
89+
},
90+
{
91+
...(options ?? {}),
92+
tables: resolvedTables,
93+
// Override the abort signal since we intercept it
94+
signal: abortSignal
95+
}
96+
);
97+
} catch (error: any) {
98+
onError(error);
99+
}
100+
};
101+
102+
runOnSchemaChange(watchQuery, db, options);
103+
};
104+
105+
const baseDatabase = new BaseSQLiteDatabase('async', dialect, session, schema) as PowerSyncSQLiteDatabase<TSchema>;
106+
return Object.assign(baseDatabase, {
107+
watch: (query: WatchQuery, handler?: WatchHandler, options?: SQLWatchOptions) => watch(query, handler, options)
108+
});
52109
}
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
import { AbstractPowerSyncDatabase, column, Schema, Table } from '@powersync/common';
2+
import { PowerSyncDatabase } from '@powersync/web';
3+
import { count, eq, sql } from 'drizzle-orm';
4+
import { integer, sqliteTable, text, uniqueIndex } from 'drizzle-orm/sqlite-core';
5+
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
6+
import * as SUT from '../../src/sqlite/db';
7+
8+
vi.useRealTimers();
9+
10+
const assetsPs = new Table(
11+
{
12+
created_at: column.text,
13+
make: column.text,
14+
model: column.text,
15+
serial_number: column.text,
16+
quantity: column.integer,
17+
user_id: column.text,
18+
customer_id: column.text,
19+
description: column.text
20+
},
21+
{ indexes: { makemodel: ['make, model'] } }
22+
);
23+
24+
const customersPs = new Table({
25+
name: column.text,
26+
email: column.text
27+
});
28+
29+
const PsSchema = new Schema({ assets: assetsPs, customers: customersPs });
30+
31+
const assets = sqliteTable(
32+
'assets',
33+
{
34+
id: text('id'),
35+
created_at: text('created_at'),
36+
make: text('make'),
37+
model: text('model'),
38+
serial_number: text('serial_number'),
39+
quantity: integer('quantity'),
40+
user_id: text('user_id'),
41+
customer_id: text('customer_id'),
42+
description: text('description')
43+
},
44+
(table) => ({
45+
makemodelIndex: uniqueIndex('makemodel').on(table.make, table.model)
46+
})
47+
);
48+
49+
const customers = sqliteTable('customers', {
50+
id: text('id'),
51+
name: text('name'),
52+
email: text('email')
53+
});
54+
55+
const DrizzleSchema = { assets, customers };
56+
57+
/**
58+
* There seems to be an issue with Vitest browser mode's setTimeout and
59+
* fake timer functionality.
60+
* e.g. calling:
61+
* await new Promise<void>((resolve) => setTimeout(resolve, 10));
62+
* waits for 1 second instead of 10ms.
63+
* Setting this to 1 second as a work around.
64+
*/
65+
const throttleDuration = 1000;
66+
67+
describe('Watch Tests', () => {
68+
let powerSyncDb: AbstractPowerSyncDatabase;
69+
let db: SUT.PowerSyncSQLiteDatabase<typeof DrizzleSchema>;
70+
71+
beforeEach(async () => {
72+
powerSyncDb = new PowerSyncDatabase({
73+
database: {
74+
dbFilename: 'test.db'
75+
},
76+
schema: PsSchema
77+
});
78+
db = SUT.wrapPowerSyncWithDrizzle(powerSyncDb, { schema: DrizzleSchema, logger: { logQuery: () => {} } });
79+
80+
await powerSyncDb.init();
81+
});
82+
83+
afterEach(async () => {
84+
await powerSyncDb.disconnectAndClear();
85+
});
86+
87+
it('watch outside throttle limits', async () => {
88+
const abortController = new AbortController();
89+
90+
const updatesCount = 2;
91+
let receivedUpdatesCount = 0;
92+
93+
/**
94+
* Promise which resolves once we received the same amount of update
95+
* notifications as there are inserts.
96+
*/
97+
const receivedUpdates = new Promise<void>((resolve) => {
98+
const onUpdate = () => {
99+
receivedUpdatesCount++;
100+
101+
if (receivedUpdatesCount == updatesCount) {
102+
abortController.abort();
103+
resolve();
104+
}
105+
};
106+
107+
const query = db
108+
.select({ count: count() })
109+
.from(assets)
110+
.innerJoin(customers, eq(customers.id, assets.customer_id));
111+
112+
db.watch(query, { onResult: onUpdate }, { signal: abortController.signal, throttleMs: throttleDuration });
113+
});
114+
115+
for (let updateCount = 0; updateCount < updatesCount; updateCount++) {
116+
await db
117+
.insert(assets)
118+
.values({
119+
id: sql`uuid()`,
120+
make: 'test',
121+
customer_id: sql`uuid()`
122+
})
123+
.execute();
124+
125+
// Wait the throttle duration, ensuring a watch update for each insert
126+
await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration));
127+
}
128+
129+
await receivedUpdates;
130+
expect(receivedUpdatesCount).equals(updatesCount);
131+
});
132+
133+
it('watch inside throttle limits', async () => {
134+
const abortController = new AbortController();
135+
136+
const updatesCount = 5;
137+
let receivedUpdatesCount = 0;
138+
139+
const onUpdate = () => {
140+
receivedUpdatesCount++;
141+
};
142+
const query = db.select({ count: count() }).from(assets).innerJoin(customers, eq(customers.id, assets.customer_id));
143+
db.watch(query, { onResult: onUpdate }, { signal: abortController.signal, throttleMs: throttleDuration });
144+
145+
// Create the inserts as fast as possible
146+
for (let updateCount = 0; updateCount < updatesCount; updateCount++) {
147+
await db
148+
.insert(assets)
149+
.values({
150+
id: sql`uuid()`,
151+
make: 'test',
152+
customer_id: sql`uuid()`
153+
})
154+
.execute();
155+
}
156+
157+
await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration * 2));
158+
abortController.abort();
159+
160+
// There should be one initial result plus one throttled result
161+
expect(receivedUpdatesCount).equals(2);
162+
});
163+
164+
it('should only watch tables inside query', async () => {
165+
const assetsAbortController = new AbortController();
166+
167+
let receivedAssetsUpdatesCount = 0;
168+
const onWatchAssets = () => {
169+
receivedAssetsUpdatesCount++;
170+
};
171+
172+
const queryAssets = db.select({ count: count() }).from(assets);
173+
174+
db.watch(
175+
queryAssets,
176+
{ onResult: onWatchAssets },
177+
{
178+
signal: assetsAbortController.signal
179+
}
180+
);
181+
182+
const customersAbortController = new AbortController();
183+
184+
let receivedCustomersUpdatesCount = 0;
185+
const onWatchCustomers = () => {
186+
receivedCustomersUpdatesCount++;
187+
};
188+
189+
const queryCustomers = db.select({ count: count() }).from(customers);
190+
db.watch(
191+
queryCustomers,
192+
{ onResult: onWatchCustomers },
193+
{
194+
signal: customersAbortController.signal
195+
}
196+
);
197+
198+
// Ensures insert doesn't form part of initial result
199+
await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration));
200+
201+
await db
202+
.insert(assets)
203+
.values({
204+
id: sql`uuid()`,
205+
make: 'test',
206+
customer_id: sql`uuid()`
207+
})
208+
.execute();
209+
210+
await new Promise<void>((resolve) => setTimeout(resolve, throttleDuration * 2));
211+
assetsAbortController.abort();
212+
customersAbortController.abort();
213+
214+
// There should be one initial result plus one throttled result
215+
expect(receivedAssetsUpdatesCount).equals(2);
216+
217+
// Only the initial result should have yielded.
218+
expect(receivedCustomersUpdatesCount).equals(1);
219+
});
220+
221+
it('should handle watch onError', async () => {
222+
const abortController = new AbortController();
223+
const onResult = () => {}; // no-op
224+
let receivedErrorCount = 0;
225+
226+
const receivedError = new Promise<void>(async (resolve) => {
227+
const onError = () => {
228+
receivedErrorCount++;
229+
resolve();
230+
};
231+
232+
const query = db
233+
.select({
234+
id: sql`fakeFunction()` // Simulate an error with invalid function
235+
})
236+
.from(assets);
237+
238+
db.watch(query, { onResult, onError }, { signal: abortController.signal, throttleMs: throttleDuration });
239+
});
240+
abortController.abort();
241+
242+
await receivedError;
243+
expect(receivedErrorCount).equals(1);
244+
});
245+
246+
it('should throttle watch overflow', async () => {
247+
const overflowAbortController = new AbortController();
248+
const updatesCount = 25;
249+
250+
let receivedWithManagedOverflowCount = 0;
251+
const firstResultReceived = new Promise<void>((resolve) => {
252+
const onResultOverflow = () => {
253+
if (receivedWithManagedOverflowCount === 0) {
254+
resolve();
255+
}
256+
receivedWithManagedOverflowCount++;
257+
};
258+
const query = db.select({ count: count() }).from(assets);
259+
db.watch(query, { onResult: onResultOverflow }, { signal: overflowAbortController.signal, throttleMs: 1 });
260+
});
261+
262+
await firstResultReceived;
263+
264+
// Perform a large number of inserts to trigger overflow
265+
for (let i = 0; i < updatesCount; i++) {
266+
db.insert(assets)
267+
.values({
268+
id: sql`uuid()`,
269+
make: 'test',
270+
customer_id: sql`uuid()`
271+
})
272+
.execute();
273+
}
274+
275+
await new Promise<void>((resolve) => setTimeout(resolve, 1 * throttleDuration));
276+
277+
overflowAbortController.abort();
278+
279+
// This fluctuates between 3 and 4 based on timing, but should never be 25
280+
expect(receivedWithManagedOverflowCount).greaterThan(2);
281+
expect(receivedWithManagedOverflowCount).toBeLessThanOrEqual(4);
282+
});
283+
});

0 commit comments

Comments
 (0)