Skip to content

Commit d3b5c4e

Browse files
committed
wip create db
1 parent 29e35b8 commit d3b5c4e

File tree

1 file changed

+97
-35
lines changed

1 file changed

+97
-35
lines changed

Tests/AppTests/Util.swift

Lines changed: 97 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,52 @@ import NIOConcurrencyHelpers
2626
private let _schemaCreated = NIOLockedValueBox<Bool>(false)
2727

2828
func setup(_ environment: Environment, resetDb: Bool = true) async throws -> Application {
29+
if !(_schemaCreated.withLockedValue { $0 }) {
30+
print("Creating initial schema...")
31+
await DotEnvFile.load(for: environment, fileio: .init(threadPool: .singleton))
32+
let testDb = Environment.get("DATABASE_NAME")!
33+
do {
34+
try await withDatabase("postgres") {
35+
try await $0.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(testDb) WITH (FORCE)"))
36+
try await $0.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(testDb)"))
37+
}
38+
} catch {
39+
print(String(reflecting: error))
40+
throw error
41+
}
42+
do { // ensure we re-create the schema when running the first test
43+
let app = try await Application.make(environment)
44+
try await configure(app)
45+
try await app.autoMigrate()
46+
_schemaCreated.withLockedValue { $0 = true }
47+
try await app.asyncShutdown()
48+
} catch {
49+
print(String(reflecting: error))
50+
throw error
51+
}
52+
print("Created initial schema.")
53+
}
54+
55+
if resetDb {
56+
let start = Date()
57+
defer { print("Resetting database took: \(Date().timeIntervalSince(start))s") }
58+
try await _resetDb()
59+
// try await RecentPackage.refresh(on: app.db)
60+
// try await RecentRelease.refresh(on: app.db)
61+
// try await Search.refresh(on: app.db)
62+
// try await Stats.refresh(on: app.db)
63+
// try await WeightedKeyword.refresh(on: app.db)
64+
}
65+
2966
let app = try await Application.make(environment)
3067
let host = try await configure(app)
3168

3269
// Ensure `.testing` refers to "postgres" or "localhost"
3370
precondition(["localhost", "postgres", "host.docker.internal"].contains(host),
3471
".testing must be a local db, was: \(host)")
3572

36-
app.logger.logLevel = Environment.get("LOG_LEVEL").flatMap(Logger.Level.init(rawValue:)) ?? .warning
3773

38-
if !(_schemaCreated.withLockedValue { $0 }) {
39-
// ensure we create the schema when running the first test
40-
try await app.autoMigrate()
41-
_schemaCreated.withLockedValue { $0 = true }
42-
}
43-
if resetDb { try await _resetDb(app) }
74+
app.logger.logLevel = Environment.get("LOG_LEVEL").flatMap(Logger.Level.init(rawValue:)) ?? .warning
4475

4576
// Always start with a baseline mock environment to avoid hitting live resources
4677
Current = .mock(eventLoop: app.eventLoopGroup.next())
@@ -49,39 +80,70 @@ func setup(_ environment: Environment, resetDb: Bool = true) async throws -> App
4980
}
5081

5182

52-
private let tableNamesCache: NIOLockedValueBox<[String]?> = .init(nil)
83+
import PostgresNIO
5384

54-
func _resetDb(_ app: Application) async throws {
55-
guard let db = app.db as? SQLDatabase else {
56-
fatalError("Database must be an SQLDatabase ('as? SQLDatabase' must succeed)")
57-
}
85+
func connect(to databaseName: String) throws -> PostgresClient {
86+
let host = Environment.get("DATABASE_HOST")!
87+
let port = Environment.get("DATABASE_PORT").flatMap(Int.init)!
88+
let username = Environment.get("DATABASE_USERNAME")!
89+
let password = Environment.get("DATABASE_PASSWORD")!
5890

59-
guard let tables = tableNamesCache.withLockedValue({ $0 }) else {
60-
struct Row: Decodable { var table_name: String }
61-
let tableNames = try await db.raw("""
62-
SELECT table_name FROM
63-
information_schema.tables
64-
WHERE
65-
table_schema NOT IN ('pg_catalog', 'information_schema', 'public._fluent_migrations')
66-
AND table_schema NOT LIKE 'pg_toast%'
67-
AND table_name NOT LIKE '_fluent_%'
68-
""")
69-
.all(decoding: Row.self)
70-
.map(\.table_name)
71-
tableNamesCache.withLockedValue { $0 = tableNames }
72-
try await _resetDb(app)
73-
return
74-
}
91+
let config = PostgresClient.Configuration(host: host, port: port, username: username, password: password, database: databaseName, tls: .disable)
92+
return .init(configuration: config)
93+
}
94+
95+
func withDatabase(_ databaseName: String, _ query: @escaping (PostgresClient) async throws -> Void) async throws {
96+
let client = try connect(to: databaseName)
97+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
98+
taskGroup.addTask {
99+
await client.run()
100+
}
101+
102+
try await query(client)
75103

76-
for table in tables {
77-
try await db.raw("TRUNCATE TABLE \(ident: table) CASCADE").run()
104+
taskGroup.cancelAll()
78105
}
106+
}
107+
79108

80-
try await RecentPackage.refresh(on: app.db)
81-
try await RecentRelease.refresh(on: app.db)
82-
try await Search.refresh(on: app.db)
83-
try await Stats.refresh(on: app.db)
84-
try await WeightedKeyword.refresh(on: app.db)
109+
private let tableNamesCache: NIOLockedValueBox<[String]?> = .init(nil)
110+
private let snapshotCreated = ActorIsolated(false)
111+
112+
func _resetDb() async throws {
113+
// FIXME: get this dynamically
114+
let dbName = "spi_test"
115+
let templateName = dbName + "_template"
116+
117+
try await snapshotCreated.withValue { snapshotCreated in
118+
if snapshotCreated {
119+
// delete db and re-create from snapshot
120+
print("Deleting and re-creating from snapshot...")
121+
do {
122+
try await withDatabase("postgres") { client in
123+
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(dbName) WITH (FORCE)"))
124+
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(dbName) TEMPLATE \(templateName)"))
125+
}
126+
} catch {
127+
print(String(reflecting: error))
128+
throw error
129+
}
130+
print("Database reset.")
131+
} else {
132+
// create snapshot
133+
print("Creating snapshot...")
134+
do {
135+
try await withDatabase("postgres") { client in
136+
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(templateName) WITH (FORCE)"))
137+
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(templateName) TEMPLATE \(dbName)"))
138+
}
139+
} catch {
140+
print(String(reflecting: error))
141+
throw error
142+
}
143+
snapshotCreated = true
144+
print("Snapshot created.")
145+
}
146+
}
85147
}
86148

87149

0 commit comments

Comments
 (0)