Skip to content

Commit c365440

Browse files
committed
Make sure the postgres helpers are safer to use by always loading env variables before using Environment.get
1 parent a9ae265 commit c365440

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

Tests/AppTests/AppTestCase.swift

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,18 @@ extension AppTestCase {
7272
try await snapshotCreated.withValue { snapshotCreated in
7373
if !snapshotCreated {
7474
try await createSchema(environment, databaseName: testDbName)
75-
try await createSnapshot(original: testDbName, snapshot: snapshotName)
75+
try await createSnapshot(original: testDbName, snapshot: snapshotName, environment: environment)
7676
snapshotCreated = true
7777
}
7878
}
7979

80-
try await restoreSnapshot(original: testDbName, snapshot: snapshotName)
80+
try await restoreSnapshot(original: testDbName, snapshot: snapshotName, environment: environment)
8181
}
8282

8383

8484
static func createSchema(_ environment: Environment, databaseName: String) async throws {
8585
do {
86-
try await withDatabase("postgres") { // Connect to `postgres` db in order to reset the test db
86+
try await withDatabase("postgres", .testing) { // Connect to `postgres` db in order to reset the test db
8787
try await $0.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(databaseName) WITH (FORCE)"))
8888
try await $0.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(databaseName)"))
8989
}
@@ -102,9 +102,9 @@ extension AppTestCase {
102102
}
103103

104104

105-
static func createSnapshot(original: String, snapshot: String) async throws {
105+
static func createSnapshot(original: String, snapshot: String, environment: Environment) async throws {
106106
do {
107-
try await withDatabase("postgres") { client in
107+
try await withDatabase("postgres", environment) { client in
108108
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(snapshot) WITH (FORCE)"))
109109
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(snapshot) TEMPLATE \(original)"))
110110
}
@@ -115,10 +115,10 @@ extension AppTestCase {
115115
}
116116

117117

118-
static func restoreSnapshot(original: String, snapshot: String) async throws {
118+
static func restoreSnapshot(original: String, snapshot: String, environment: Environment) async throws {
119119
// delete db and re-create from snapshot
120120
do {
121-
try await withDatabase("postgres") { client in
121+
try await withDatabase("postgres", environment) { client in
122122
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(original) WITH (FORCE)"))
123123
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(original) TEMPLATE \(snapshot)"))
124124
}
@@ -170,22 +170,23 @@ extension AppTestCase {
170170
}
171171

172172

173-
private func connect(to databaseName: String) throws -> PostgresClient {
173+
private func connect(to databaseName: String, _ environment: Environment) async throws -> PostgresClient {
174+
await DotEnvFile.load(for: environment, fileio: .init(threadPool: .singleton))
174175
let host = Environment.get("DATABASE_HOST")!
175176
let port = Environment.get("DATABASE_PORT").flatMap(Int.init)!
176177
let username = Environment.get("DATABASE_USERNAME")!
177178
let password = Environment.get("DATABASE_PASSWORD")!
178179

179180
let config = PostgresClient.Configuration(host: host, port: port, username: username, password: password, database: databaseName, tls: .disable)
181+
180182
return .init(configuration: config)
181183
}
182184

183-
private func withDatabase(_ databaseName: String, _ query: @escaping (PostgresClient) async throws -> Void) async throws {
184-
let client = try connect(to: databaseName)
185+
186+
private func withDatabase(_ databaseName: String, _ environment: Environment, _ query: @escaping (PostgresClient) async throws -> Void) async throws {
187+
let client = try await connect(to: databaseName, environment)
185188
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
186-
taskGroup.addTask {
187-
await client.run()
188-
}
189+
taskGroup.addTask { await client.run() }
189190

190191
try await query(client)
191192

0 commit comments

Comments
 (0)