Skip to content

Commit a02ae58

Browse files
committed
Add snapshot based table reset
1 parent 8ec6f83 commit a02ae58

File tree

2 files changed

+127
-127
lines changed

2 files changed

+127
-127
lines changed

Tests/AppTests/AppTestCase.swift

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
@testable import App
16+
17+
import NIOConcurrencyHelpers
18+
import PostgresNIO
1519
import SQLKit
1620
import XCTVapor
17-
@testable import App
21+
1822

1923
class AppTestCase: XCTestCase {
2024
var app: Application!
@@ -23,9 +27,11 @@ class AppTestCase: XCTestCase {
2327
override func setUp() async throws {
2428
try await super.setUp()
2529
app = try await setup(.testing)
26-
Current.setLogger(.init(label: "test", factory: { _ in logger }))
27-
// Silence app logging
28-
app.logger = .init(label: "noop") { _ in SwiftLogNoOpLogHandler() }
30+
}
31+
32+
func setup(_ environment: Environment) async throws -> Application {
33+
try await Self.setupDb(environment)
34+
return try await setupApp(environment)
2935
}
3036

3137
override func tearDown() async throws {
@@ -35,6 +41,98 @@ class AppTestCase: XCTestCase {
3541
}
3642

3743

44+
extension AppTestCase {
45+
46+
func setupApp(_ environment: Environment) async throws -> Application {
47+
let app = try await Application.make(environment)
48+
let host = try await configure(app)
49+
50+
// Ensure `.testing` refers to "postgres" or "localhost"
51+
precondition(["localhost", "postgres", "host.docker.internal"].contains(host),
52+
".testing must be a local db, was: \(host)")
53+
54+
// Always start with a baseline mock environment to avoid hitting live resources
55+
Current = .mock(eventLoop: app.eventLoopGroup.next())
56+
57+
Current.setLogger(.init(label: "test", factory: { _ in logger }))
58+
// Silence app logging
59+
app.logger = .init(label: "noop") { _ in SwiftLogNoOpLogHandler() }
60+
61+
return app
62+
}
63+
64+
65+
static func setupDb(_ environment: Environment) async throws {
66+
await DotEnvFile.load(for: environment, fileio: .init(threadPool: .singleton))
67+
let testDbName = Environment.get("DATABASE_NAME")!
68+
let snapshotName = testDbName + "_snapshot"
69+
70+
// Create initial db snapshot on first run
71+
try await snapshotCreated.withValue { snapshotCreated in
72+
if !snapshotCreated {
73+
try await createSchema(environment, databaseName: testDbName)
74+
try await createSnapshot(original: testDbName, snapshot: snapshotName)
75+
snapshotCreated = true
76+
}
77+
}
78+
79+
try await restoreSnapshot(original: testDbName, snapshot: snapshotName)
80+
}
81+
82+
83+
static func createSchema(_ environment: Environment, databaseName: String) async throws {
84+
do {
85+
try await withDatabase("postgres") { // Connect to `postgres` db in order to reset the test db
86+
try await $0.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(databaseName) WITH (FORCE)"))
87+
try await $0.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(databaseName)"))
88+
}
89+
90+
do { // Use autoMigrate to spin up the schema
91+
let app = try await Application.make(environment)
92+
app.logger = .init(label: "noop") { _ in SwiftLogNoOpLogHandler() }
93+
try await configure(app)
94+
try await app.autoMigrate()
95+
try await app.asyncShutdown()
96+
}
97+
} catch {
98+
print("Create schema failed with error: ", String(reflecting: error))
99+
throw error
100+
}
101+
}
102+
103+
104+
static func createSnapshot(original: String, snapshot: String) async throws {
105+
do {
106+
try await withDatabase("postgres") { client in
107+
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(snapshot) WITH (FORCE)"))
108+
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(snapshot) TEMPLATE \(original)"))
109+
}
110+
} catch {
111+
print("Create snapshot failed with error: ", String(reflecting: error))
112+
throw error
113+
}
114+
}
115+
116+
117+
static func restoreSnapshot(original: String, snapshot: String) async throws {
118+
// delete db and re-create from snapshot
119+
do {
120+
try await withDatabase("postgres") { client in
121+
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(original) WITH (FORCE)"))
122+
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(original) TEMPLATE \(snapshot)"))
123+
}
124+
} catch {
125+
print("Restore snapshot failed with error: ", String(reflecting: error))
126+
throw error
127+
}
128+
}
129+
130+
131+
static let snapshotCreated = ActorIsolated(false)
132+
133+
}
134+
135+
38136
extension AppTestCase {
39137
func renderSQL(_ builder: SQLSelectBuilder) -> String {
40138
renderSQL(builder.query)
@@ -69,3 +167,28 @@ extension AppTestCase {
69167
}
70168
}
71169
}
170+
171+
172+
private func connect(to databaseName: String) throws -> PostgresClient {
173+
let host = Environment.get("DATABASE_HOST")!
174+
let port = Environment.get("DATABASE_PORT").flatMap(Int.init)!
175+
let username = Environment.get("DATABASE_USERNAME")!
176+
let password = Environment.get("DATABASE_PASSWORD")!
177+
178+
let config = PostgresClient.Configuration(host: host, port: port, username: username, password: password, database: databaseName, tls: .disable)
179+
return .init(configuration: config)
180+
}
181+
182+
private func withDatabase(_ databaseName: String, _ query: @escaping (PostgresClient) async throws -> Void) async throws {
183+
let client = try connect(to: databaseName)
184+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
185+
taskGroup.addTask {
186+
await client.run()
187+
}
188+
189+
try await query(client)
190+
191+
taskGroup.cancelAll()
192+
}
193+
}
194+

Tests/AppTests/Util.swift

Lines changed: 0 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -23,129 +23,6 @@ import NIOConcurrencyHelpers
2323

2424
// MARK: - Test helpers
2525

26-
private let _schemaCreated = NIOLockedValueBox<Bool>(false)
27-
28-
func setup(_ environment: Environment, resetDb: Bool = true) async throws -> Application {
29-
if !(_schemaCreated.withLockedValue { $0 }) {
30-
print("Creating initial schema...")
31-
await DotEnvFile.load(for: environment, fileio: .init(threadPool: .singleton))
32-
let testDb = Environment.get("DATABASE_NAME")!
33-
do {
34-
try await withDatabase("postgres") {
35-
try await $0.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(testDb) WITH (FORCE)"))
36-
try await $0.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(testDb)"))
37-
}
38-
} catch {
39-
print(String(reflecting: error))
40-
throw error
41-
}
42-
do { // ensure we re-create the schema when running the first test
43-
let app = try await Application.make(environment)
44-
try await configure(app)
45-
try await app.autoMigrate()
46-
_schemaCreated.withLockedValue { $0 = true }
47-
try await app.asyncShutdown()
48-
} catch {
49-
print(String(reflecting: error))
50-
throw error
51-
}
52-
print("Created initial schema.")
53-
}
54-
55-
if resetDb {
56-
let start = Date()
57-
defer { print("Resetting database took: \(Date().timeIntervalSince(start))s") }
58-
try await _resetDb()
59-
// try await RecentPackage.refresh(on: app.db)
60-
// try await RecentRelease.refresh(on: app.db)
61-
// try await Search.refresh(on: app.db)
62-
// try await Stats.refresh(on: app.db)
63-
// try await WeightedKeyword.refresh(on: app.db)
64-
}
65-
66-
let app = try await Application.make(environment)
67-
let host = try await configure(app)
68-
69-
// Ensure `.testing` refers to "postgres" or "localhost"
70-
precondition(["localhost", "postgres", "host.docker.internal"].contains(host),
71-
".testing must be a local db, was: \(host)")
72-
73-
74-
app.logger.logLevel = Environment.get("LOG_LEVEL").flatMap(Logger.Level.init(rawValue:)) ?? .warning
75-
76-
// Always start with a baseline mock environment to avoid hitting live resources
77-
Current = .mock(eventLoop: app.eventLoopGroup.next())
78-
79-
return app
80-
}
81-
82-
83-
import PostgresNIO
84-
85-
func connect(to databaseName: String) throws -> PostgresClient {
86-
let host = Environment.get("DATABASE_HOST")!
87-
let port = Environment.get("DATABASE_PORT").flatMap(Int.init)!
88-
let username = Environment.get("DATABASE_USERNAME")!
89-
let password = Environment.get("DATABASE_PASSWORD")!
90-
91-
let config = PostgresClient.Configuration(host: host, port: port, username: username, password: password, database: databaseName, tls: .disable)
92-
return .init(configuration: config)
93-
}
94-
95-
func withDatabase(_ databaseName: String, _ query: @escaping (PostgresClient) async throws -> Void) async throws {
96-
let client = try connect(to: databaseName)
97-
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
98-
taskGroup.addTask {
99-
await client.run()
100-
}
101-
102-
try await query(client)
103-
104-
taskGroup.cancelAll()
105-
}
106-
}
107-
108-
109-
private let tableNamesCache: NIOLockedValueBox<[String]?> = .init(nil)
110-
private let snapshotCreated = ActorIsolated(false)
111-
112-
func _resetDb() async throws {
113-
// FIXME: get this dynamically
114-
let dbName = "spi_test"
115-
let templateName = dbName + "_template"
116-
117-
try await snapshotCreated.withValue { snapshotCreated in
118-
if snapshotCreated {
119-
// delete db and re-create from snapshot
120-
print("Deleting and re-creating from snapshot...")
121-
do {
122-
try await withDatabase("postgres") { client in
123-
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(dbName) WITH (FORCE)"))
124-
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(dbName) TEMPLATE \(templateName)"))
125-
}
126-
} catch {
127-
print(String(reflecting: error))
128-
throw error
129-
}
130-
print("Database reset.")
131-
} else {
132-
// create snapshot
133-
print("Creating snapshot...")
134-
do {
135-
try await withDatabase("postgres") { client in
136-
try await client.query(PostgresQuery(unsafeSQL: "DROP DATABASE IF EXISTS \(templateName) WITH (FORCE)"))
137-
try await client.query(PostgresQuery(unsafeSQL: "CREATE DATABASE \(templateName) TEMPLATE \(dbName)"))
138-
}
139-
} catch {
140-
print(String(reflecting: error))
141-
throw error
142-
}
143-
snapshotCreated = true
144-
print("Snapshot created.")
145-
}
146-
}
147-
}
148-
14926

15027
func fixtureString(for fixture: String) throws -> String {
15128
String(decoding: try fixtureData(for: fixture), as: UTF8.self)

0 commit comments

Comments
 (0)