1515import Foundation
1616
1717import App
18+ import PostgresNIO
1819import ShellOut
20+ import Vapor
21+
1922
2023
2124actor DatabasePool {
@@ -40,7 +43,9 @@ actor DatabasePool {
4043 try await withThrowingTaskGroup ( of: Database . self) { group in
4144 for _ in ( 0 ..< maxCount) {
4245 group. addTask {
43- try await self . launchDB ( )
46+ let db = try await self . launchDB ( )
47+ try await db. setup ( for: . testing)
48+ return db
4449 }
4550 }
4651 for try await info in group {
@@ -107,34 +112,40 @@ actor DatabasePool {
107112}
108113
109114
110- import PostgresNIO
111- import Vapor
112-
113115extension DatabasePool . Database {
114116
115- func setupDb( _ environment: Environment ) async throws {
116- await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
117-
118- // Ensure DATABASE_HOST is from a restricted set db hostnames and nothing else.
119- // This is safeguard against accidental inheritance of setup in QueryPerformanceTests
120- // and to ensure the database resetting cannot impact any other network hosts.
121- let host = Environment . get ( " DATABASE_HOST " ) !
122- precondition ( [ " localhost " , " postgres " , " host.docker.internal " ] . contains ( host) ,
123- " DATABASE_HOST must be a local db, was: \( host) " )
117+ struct ConnectionDetails {
118+ var host : String
119+ var port : Int
120+ var username : String
121+ var password : String
122+
123+ init ( with environment: Environment , port: Int ) {
124+ // Ensure DATABASE_HOST is from a restricted set db hostnames and nothing else.
125+ // This is safeguard against accidental inheritance of setup in QueryPerformanceTests
126+ // and to ensure the database resetting cannot impact any other network hosts.
127+ self . host = Environment . get ( " DATABASE_HOST " ) !
128+ precondition ( [ " localhost " , " postgres " , " host.docker.internal " ] . contains ( host) ,
129+ " DATABASE_HOST must be a local db, was: \( host) " )
130+ self . port = port
131+ self . username = Environment . get ( " DATABASE_USERNAME " ) !
132+ self . password = Environment . get ( " DATABASE_PASSWORD " ) !
133+ }
134+ }
124135
125- let testDbName = Environment . get ( " DATABASE_NAME " ) !
126- let snapshotName = testDbName + " _snapshot "
136+ func setup( for environment: Environment ) async throws {
137+ await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
138+ let details = ConnectionDetails ( with: environment, port: port)
127139
128140 // Create initial db snapshot
129- try await createSchema ( environment, databaseName: testDbName)
130- try await createSnapshot ( original: testDbName, snapshot: snapshotName, environment: environment)
131-
132- try await restoreSnapshot ( original: testDbName, snapshot: snapshotName, environment: environment)
141+ try await createSchema ( environment, details: details)
142+ try await createSnapshot ( details: details)
133143 }
134144
135- func createSchema( _ environment: Environment , databaseName : String ) async throws {
145+ func createSchema( _ environment: Environment , details : ConnectionDetails ) async throws {
136146 do {
137- try await _withDatabase ( " postgres " , port: port, environment) { // Connect to `postgres` db in order to reset the test db
147+ try await _withDatabase ( " postgres " , details: details) { // Connect to `postgres` db in order to reset the test db
148+ let databaseName = Environment . get ( " DATABASE_NAME " ) !
138149 try await $0. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( databaseName) WITH (FORCE) " ) )
139150 try await $0. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( databaseName) " ) )
140151 }
@@ -152,9 +163,11 @@ extension DatabasePool.Database {
152163 }
153164 }
154165
155- func createSnapshot( original: String , snapshot: String , environment: Environment ) async throws {
166+ func createSnapshot( details: ConnectionDetails ) async throws {
167+ let original = Environment . get ( " DATABASE_NAME " ) !
168+ let snapshot = original + " _snapshot "
156169 do {
157- try await _withDatabase ( " postgres " , port : port , environment ) { client in
170+ try await _withDatabase ( " postgres " , details : details ) { client in
158171 try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( snapshot) WITH (FORCE) " ) )
159172 try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( snapshot) TEMPLATE \( original) " ) )
160173 }
@@ -164,12 +177,12 @@ extension DatabasePool.Database {
164177 }
165178 }
166179
167- func restoreSnapshot( original : String ,
168- snapshot : String ,
169- environment : Environment ) async throws {
180+ func restoreSnapshot( details : ConnectionDetails ) async throws {
181+ let original = Environment . get ( " DATABASE_NAME " ) !
182+ let snapshot = original + " _snapshot "
170183 // delete db and re-create from snapshot
171184 do {
172- try await _withDatabase ( " postgres " , port : port , environment ) { client in
185+ try await _withDatabase ( " postgres " , details : details ) { client in
173186 try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( original) WITH (FORCE) " ) )
174187 try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( original) TEMPLATE \( snapshot) " ) )
175188 }
@@ -182,26 +195,24 @@ extension DatabasePool.Database {
182195}
183196
184197
185- private func connect( to databaseName: String ,
186- port: Int ,
187- _ environment: Environment ) async throws -> PostgresClient {
188- #warning("don't load dot file, just pass in host, port, username, password tuple - or make this a method and the other values properties")
189- await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
190- let host = Environment . get ( " DATABASE_HOST " ) !
191- let username = Environment . get ( " DATABASE_USERNAME " ) !
192- let password = Environment . get ( " DATABASE_PASSWORD " ) !
193-
194- let config = PostgresClient . Configuration ( host: host, port: port, username: username, password: password, database: databaseName, tls: . disable)
198+ private func connect( to databaseName: String , details: DatabasePool . Database . ConnectionDetails ) async throws -> PostgresClient {
199+ let config = PostgresClient . Configuration (
200+ host: details. host,
201+ port: details. port,
202+ username: details. username,
203+ password: details. password,
204+ database: databaseName,
205+ tls: . disable
206+ )
195207
196208 return . init( configuration: config)
197209}
198210
199211
200212private func _withDatabase( _ databaseName: String ,
201- port: Int ,
202- _ environment: Environment ,
213+ details: DatabasePool . Database . ConnectionDetails ,
203214 _ query: @escaping ( PostgresClient ) async throws -> Void ) async throws {
204- let client = try await connect ( to: databaseName, port : port , environment )
215+ let client = try await connect ( to: databaseName, details : details )
205216 try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
206217 taskGroup. addTask { await client. run ( ) }
207218
0 commit comments