|
| 1 | +import * as fs from 'node:fs' |
| 2 | +import * as path from 'node:path' |
1 | 3 | import { test as testBase } from 'vitest' |
2 | 4 | import sqlite3, { type Database } from 'sqlite3' |
| 5 | +import * as databaseModule from './src/database' |
3 | 6 |
|
4 | 7 | interface Fixtures { |
5 | 8 | createMockDatabase: ( |
6 | | - init: (database: Database, done: () => void) => void, |
| 9 | + seed: (database: Database, handle: CallbackHandle) => void, |
7 | 10 | ) => Promise<void> |
8 | 11 | } |
9 | 12 |
|
| 13 | +type CallbackHandle = (error?: Error | null, ...args: Array<any>) => void |
| 14 | + |
| 15 | +function toPromise<T>(init: (handle: CallbackHandle) => T): Promise<T> { |
| 16 | + return new Promise<T>((resolve, reject) => { |
| 17 | + const result = init((error) => { |
| 18 | + if (error) { |
| 19 | + return reject(error) |
| 20 | + } |
| 21 | + resolve(result) |
| 22 | + }) |
| 23 | + }) |
| 24 | +} |
| 25 | + |
10 | 26 | export const test = testBase.extend<Fixtures>({ |
11 | 27 | createMockDatabase: [ |
12 | | - async ({ task }, use) => { |
13 | | - const db = new sqlite3.Database(':memory:') |
| 28 | + async ({ task, onTestFinished }, use) => { |
| 29 | + const dbFile = `${task.file.filepath}-${task.id}.sqlite` |
| 30 | + |
| 31 | + if (fs.existsSync(dbFile)) { |
| 32 | + await fs.promises.rm(dbFile) |
| 33 | + } |
| 34 | + |
| 35 | + const mockDatabase = await toPromise((handle) => { |
| 36 | + return new sqlite3.Database(dbFile, handle) |
| 37 | + }) |
14 | 38 |
|
15 | | - vi.doMock(import('./src/db'), async (importOriginal) => { |
16 | | - const original = await importOriginal() |
17 | | - const mockClient = new original.DatabaseClient(db) |
| 39 | + onTestFinished(async ({ task }) => { |
| 40 | + await toPromise((handle) => mockDatabase.close(handle)) |
18 | 41 |
|
19 | | - return { |
20 | | - ...original, |
21 | | - client: mockClient, |
| 42 | + if (task.type !== 'test') { |
| 43 | + return |
| 44 | + } |
| 45 | + |
| 46 | + if (task.result?.state === 'pass') { |
| 47 | + await fs.promises.rm(dbFile) |
| 48 | + } else { |
| 49 | + task.result?.errors?.push({ |
| 50 | + name: 'Mock database', |
| 51 | + message: 'See the database state:', |
| 52 | + codeFrame: path.relative(process.cwd(), dbFile), |
| 53 | + }) |
22 | 54 | } |
23 | 55 | }) |
24 | 56 |
|
25 | | - await new Promise<void>((resolve) => { |
26 | | - db.serialize(() => { |
27 | | - db.run('CREATE TABLE users (id TEXT, name TEXT)', resolve) |
28 | | - }) |
| 57 | + const clientSpy = vi |
| 58 | + .spyOn(databaseModule, 'client', 'get') |
| 59 | + .mockReturnValue(new databaseModule.DatabaseClient(mockDatabase)) |
| 60 | + |
| 61 | + await toPromise((handle) => { |
| 62 | + mockDatabase.run('CREATE TABLE users (id TEXT, name TEXT)', handle) |
29 | 63 | }) |
30 | 64 |
|
31 | | - await use((initDatabase) => { |
32 | | - return new Promise((resolve) => { |
33 | | - initDatabase(db, resolve) |
| 65 | + await use((seed) => { |
| 66 | + return toPromise((handle) => { |
| 67 | + seed(mockDatabase, handle) |
34 | 68 | }) |
35 | 69 | }) |
36 | 70 |
|
37 | | - await new Promise<void>((resolve) => { |
38 | | - db.close(() => resolve()) |
39 | | - }) |
40 | | - vi.doUnmock(import('./src/db')) |
| 71 | + clientSpy.mockRestore() |
41 | 72 | }, |
42 | 73 | { |
43 | 74 | auto: true, |
|
0 commit comments