@@ -21,12 +21,13 @@ import ShellOut
2121actor DatabasePool {
2222 typealias DatabaseID = UUID
2323
24+ #warning("rename to Database")
2425 struct DatabaseInfo : Hashable {
2526 var id : DatabaseID
2627 var port : Int
2728 }
2829
29- static let shared = DatabasePool ( maxCount: 4 )
30+ static let shared = DatabasePool ( maxCount: 8 )
3031
3132 var maxCount : Int
3233
@@ -63,7 +64,7 @@ actor DatabasePool {
6364 func withDatabase( _ operation: @Sendable ( DatabaseInfo) async throws -> Void ) async throws {
6465 let dbID = try await retainDatabase ( )
6566 do {
66- print ( " ⚠️ available " , availableDatabases. map ( \. port) . sorted ( ) )
67+ // print("⚠️ available", availableDatabases.map(\.port).sorted())
6768 try await operation ( dbID)
6869 try await releaseDatabase ( dbInfo: dbID)
6970 } catch {
@@ -100,8 +101,114 @@ actor DatabasePool {
100101
101102 private func removeDB( dbInfo: DatabaseInfo , maxAttempts: Int = 3 ) async throws {
102103 try await run ( maxAttempts: 3 ) { attempt in
103- print ( " ⚠️ Removing DB \( dbInfo. id) on port \( dbInfo. port) (attempt: \( attempt) ) " )
104+ // print("⚠️ Removing DB \(dbInfo.id) on port \(dbInfo.port) (attempt: \(attempt))")
104105 try await ShellOut . shellOut ( to: . removeDB( id: dbInfo. id) )
105106 }
106107 }
107108}
109+
110+
111+ import PostgresNIO
112+ import Vapor
113+
114+ extension DatabasePool . DatabaseInfo {
115+
116+ func setupDb( _ environment: Environment ) async throws {
117+ await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
118+
119+ // Ensure DATABASE_HOST is from a restricted set db hostnames and nothing else.
120+ // This is safeguard against accidental inheritance of setup in QueryPerformanceTests
121+ // and to ensure the database resetting cannot impact any other network hosts.
122+ let host = Environment . get ( " DATABASE_HOST " ) !
123+ precondition ( [ " localhost " , " postgres " , " host.docker.internal " ] . contains ( host) ,
124+ " DATABASE_HOST must be a local db, was: \( host) " )
125+
126+ let testDbName = Environment . get ( " DATABASE_NAME " ) !
127+ let snapshotName = testDbName + " _snapshot "
128+
129+ // Create initial db snapshot
130+ try await createSchema ( environment, databaseName: testDbName)
131+ try await createSnapshot ( original: testDbName, snapshot: snapshotName, environment: environment)
132+
133+ try await restoreSnapshot ( original: testDbName, snapshot: snapshotName, environment: environment)
134+ }
135+
136+ func createSchema( _ environment: Environment , databaseName: String ) async throws {
137+ do {
138+ try await _withDatabase ( " postgres " , port: port, environment) { // Connect to `postgres` db in order to reset the test db
139+ try await $0. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( databaseName) WITH (FORCE) " ) )
140+ try await $0. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( databaseName) " ) )
141+ }
142+
143+ do { // Use autoMigrate to spin up the schema
144+ let app = try await Application . make ( environment)
145+ app. logger = . init( label: " noop " ) { _ in SwiftLogNoOpLogHandler ( ) }
146+ try await configure ( app, databasePort: port)
147+ try await app. autoMigrate ( )
148+ try await app. asyncShutdown ( )
149+ }
150+ } catch {
151+ print ( " Create schema failed with error: " , String ( reflecting: error) )
152+ throw error
153+ }
154+ }
155+
156+ func createSnapshot( original: String , snapshot: String , environment: Environment ) async throws {
157+ do {
158+ try await _withDatabase ( " postgres " , port: port, environment) { client in
159+ try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( snapshot) WITH (FORCE) " ) )
160+ try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( snapshot) TEMPLATE \( original) " ) )
161+ }
162+ } catch {
163+ print ( " Create snapshot failed with error: " , String ( reflecting: error) )
164+ throw error
165+ }
166+ }
167+
168+ func restoreSnapshot( original: String ,
169+ snapshot: String ,
170+ environment: Environment ) async throws {
171+ // delete db and re-create from snapshot
172+ do {
173+ try await _withDatabase ( " postgres " , port: port, environment) { client in
174+ try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( original) WITH (FORCE) " ) )
175+ try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( original) TEMPLATE \( snapshot) " ) )
176+ }
177+ } catch {
178+ print ( " Restore snapshot failed with error: " , String ( reflecting: error) )
179+ throw error
180+ }
181+ }
182+
183+ }
184+
185+
186+ private func connect( to databaseName: String ,
187+ port: Int ,
188+ _ environment: Environment ) async throws -> PostgresClient {
189+ #warning("don't load dot file, just pass in host, port, username, password tuple - or make this a method and the other values properties")
190+ await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
191+ let host = Environment . get ( " DATABASE_HOST " ) !
192+ let username = Environment . get ( " DATABASE_USERNAME " ) !
193+ let password = Environment . get ( " DATABASE_PASSWORD " ) !
194+
195+ let config = PostgresClient . Configuration ( host: host, port: port, username: username, password: password, database: databaseName, tls: . disable)
196+
197+ return . init( configuration: config)
198+ }
199+
200+
201+ private func _withDatabase( _ databaseName: String ,
202+ port: Int ,
203+ _ environment: Environment ,
204+ _ query: @escaping ( PostgresClient ) async throws -> Void ) async throws {
205+ let client = try await connect ( to: databaseName, port: port, environment)
206+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
207+ taskGroup. addTask { await client. run ( ) }
208+
209+ try await query ( client)
210+
211+ taskGroup. cancelAll ( )
212+ }
213+ }
214+
0 commit comments