12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ @testable import App
16
+
17
+ import NIOConcurrencyHelpers
18
+ import PostgresNIO
15
19
import SQLKit
16
20
import XCTVapor
17
- @ testable import App
21
+
18
22
19
23
class AppTestCase : XCTestCase {
20
24
var app : Application !
@@ -23,9 +27,11 @@ class AppTestCase: XCTestCase {
23
27
override func setUp( ) async throws {
24
28
try await super. setUp ( )
25
29
app = try await setup ( . testing)
26
- Current . setLogger ( . init( label: " test " , factory: { _ in logger } ) )
27
- // Silence app logging
28
- app. logger = . init( label: " noop " ) { _ in SwiftLogNoOpLogHandler ( ) }
30
+ }
31
+
32
+ func setup( _ environment: Environment ) async throws -> Application {
33
+ try await Self . setupDb ( environment)
34
+ return try await setupApp ( environment)
29
35
}
30
36
31
37
override func tearDown( ) async throws {
@@ -35,6 +41,98 @@ class AppTestCase: XCTestCase {
35
41
}
36
42
37
43
44
+ extension AppTestCase {
45
+
46
+ func setupApp( _ environment: Environment ) async throws -> Application {
47
+ let app = try await Application . make ( environment)
48
+ let host = try await configure ( app)
49
+
50
+ // Ensure `.testing` refers to "postgres" or "localhost"
51
+ precondition ( [ " localhost " , " postgres " , " host.docker.internal " ] . contains ( host) ,
52
+ " .testing must be a local db, was: \( host) " )
53
+
54
+ // Always start with a baseline mock environment to avoid hitting live resources
55
+ Current = . mock( eventLoop: app. eventLoopGroup. next ( ) )
56
+
57
+ Current . setLogger ( . init( label: " test " , factory: { _ in logger } ) )
58
+ // Silence app logging
59
+ app. logger = . init( label: " noop " ) { _ in SwiftLogNoOpLogHandler ( ) }
60
+
61
+ return app
62
+ }
63
+
64
+
65
+ static func setupDb( _ environment: Environment ) async throws {
66
+ await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
67
+ let testDbName = Environment . get ( " DATABASE_NAME " ) !
68
+ let snapshotName = testDbName + " _snapshot "
69
+
70
+ // Create initial db snapshot on first run
71
+ try await snapshotCreated. withValue { snapshotCreated in
72
+ if !snapshotCreated {
73
+ try await createSchema ( environment, databaseName: testDbName)
74
+ try await createSnapshot ( original: testDbName, snapshot: snapshotName)
75
+ snapshotCreated = true
76
+ }
77
+ }
78
+
79
+ try await restoreSnapshot ( original: testDbName, snapshot: snapshotName)
80
+ }
81
+
82
+
83
+ static func createSchema( _ environment: Environment , databaseName: String ) async throws {
84
+ do {
85
+ try await withDatabase ( " postgres " ) { // Connect to `postgres` db in order to reset the test db
86
+ try await $0. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( databaseName) WITH (FORCE) " ) )
87
+ try await $0. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( databaseName) " ) )
88
+ }
89
+
90
+ do { // Use autoMigrate to spin up the schema
91
+ let app = try await Application . make ( environment)
92
+ app. logger = . init( label: " noop " ) { _ in SwiftLogNoOpLogHandler ( ) }
93
+ try await configure ( app)
94
+ try await app. autoMigrate ( )
95
+ try await app. asyncShutdown ( )
96
+ }
97
+ } catch {
98
+ print ( " Create schema failed with error: " , String ( reflecting: error) )
99
+ throw error
100
+ }
101
+ }
102
+
103
+
104
+ static func createSnapshot( original: String , snapshot: String ) async throws {
105
+ do {
106
+ try await withDatabase ( " postgres " ) { client in
107
+ try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( snapshot) WITH (FORCE) " ) )
108
+ try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( snapshot) TEMPLATE \( original) " ) )
109
+ }
110
+ } catch {
111
+ print ( " Create snapshot failed with error: " , String ( reflecting: error) )
112
+ throw error
113
+ }
114
+ }
115
+
116
+
117
+ static func restoreSnapshot( original: String , snapshot: String ) async throws {
118
+ // delete db and re-create from snapshot
119
+ do {
120
+ try await withDatabase ( " postgres " ) { client in
121
+ try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( original) WITH (FORCE) " ) )
122
+ try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( original) TEMPLATE \( snapshot) " ) )
123
+ }
124
+ } catch {
125
+ print ( " Restore snapshot failed with error: " , String ( reflecting: error) )
126
+ throw error
127
+ }
128
+ }
129
+
130
+
131
+ static let snapshotCreated = ActorIsolated ( false )
132
+
133
+ }
134
+
135
+
38
136
extension AppTestCase {
39
137
func renderSQL( _ builder: SQLSelectBuilder ) -> String {
40
138
renderSQL ( builder. query)
@@ -69,3 +167,28 @@ extension AppTestCase {
69
167
}
70
168
}
71
169
}
170
+
171
+
172
+ private func connect( to databaseName: String ) throws -> PostgresClient {
173
+ let host = Environment . get ( " DATABASE_HOST " ) !
174
+ let port = Environment . get ( " DATABASE_PORT " ) . flatMap ( Int . init) !
175
+ let username = Environment . get ( " DATABASE_USERNAME " ) !
176
+ let password = Environment . get ( " DATABASE_PASSWORD " ) !
177
+
178
+ let config = PostgresClient . Configuration ( host: host, port: port, username: username, password: password, database: databaseName, tls: . disable)
179
+ return . init( configuration: config)
180
+ }
181
+
182
+ private func withDatabase( _ databaseName: String , _ query: @escaping ( PostgresClient ) async throws -> Void ) async throws {
183
+ let client = try connect ( to: databaseName)
184
+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
185
+ taskGroup. addTask {
186
+ await client. run ( )
187
+ }
188
+
189
+ try await query ( client)
190
+
191
+ taskGroup. cancelAll ( )
192
+ }
193
+ }
194
+
0 commit comments