@@ -26,21 +26,52 @@ import NIOConcurrencyHelpers
26
26
private let _schemaCreated = NIOLockedValueBox < Bool > ( false )
27
27
28
28
func setup( _ environment: Environment , resetDb: Bool = true ) async throws -> Application {
29
+ if !( _schemaCreated. withLockedValue { $0 } ) {
30
+ print ( " Creating initial schema... " )
31
+ await DotEnvFile . load ( for: environment, fileio: . init( threadPool: . singleton) )
32
+ let testDb = Environment . get ( " DATABASE_NAME " ) !
33
+ do {
34
+ try await withDatabase ( " postgres " ) {
35
+ try await $0. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( testDb) WITH (FORCE) " ) )
36
+ try await $0. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( testDb) " ) )
37
+ }
38
+ } catch {
39
+ print ( String ( reflecting: error) )
40
+ throw error
41
+ }
42
+ do { // ensure we re-create the schema when running the first test
43
+ let app = try await Application . make ( environment)
44
+ try await configure ( app)
45
+ try await app. autoMigrate ( )
46
+ _schemaCreated. withLockedValue { $0 = true }
47
+ try await app. asyncShutdown ( )
48
+ } catch {
49
+ print ( String ( reflecting: error) )
50
+ throw error
51
+ }
52
+ print ( " Created initial schema. " )
53
+ }
54
+
55
+ if resetDb {
56
+ let start = Date ( )
57
+ defer { print ( " Resetting database took: \( Date ( ) . timeIntervalSince ( start) ) s " ) }
58
+ try await _resetDb ( )
59
+ // try await RecentPackage.refresh(on: app.db)
60
+ // try await RecentRelease.refresh(on: app.db)
61
+ // try await Search.refresh(on: app.db)
62
+ // try await Stats.refresh(on: app.db)
63
+ // try await WeightedKeyword.refresh(on: app.db)
64
+ }
65
+
29
66
let app = try await Application . make ( environment)
30
67
let host = try await configure ( app)
31
68
32
69
// Ensure `.testing` refers to "postgres" or "localhost"
33
70
precondition ( [ " localhost " , " postgres " , " host.docker.internal " ] . contains ( host) ,
34
71
" .testing must be a local db, was: \( host) " )
35
72
36
- app. logger. logLevel = Environment . get ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init ( rawValue: ) ) ?? . warning
37
73
38
- if !( _schemaCreated. withLockedValue { $0 } ) {
39
- // ensure we create the schema when running the first test
40
- try await app. autoMigrate ( )
41
- _schemaCreated. withLockedValue { $0 = true }
42
- }
43
- if resetDb { try await _resetDb ( app) }
74
+ app. logger. logLevel = Environment . get ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init ( rawValue: ) ) ?? . warning
44
75
45
76
// Always start with a baseline mock environment to avoid hitting live resources
46
77
Current = . mock( eventLoop: app. eventLoopGroup. next ( ) )
@@ -49,39 +80,70 @@ func setup(_ environment: Environment, resetDb: Bool = true) async throws -> App
49
80
}
50
81
51
82
52
- private let tableNamesCache : NIOLockedValueBox < [ String ] ? > = . init ( nil )
83
+ import PostgresNIO
53
84
54
- func _resetDb( _ app: Application ) async throws {
55
- guard let db = app. db as? SQLDatabase else {
56
- fatalError ( " Database must be an SQLDatabase ('as? SQLDatabase' must succeed) " )
57
- }
85
+ func connect( to databaseName: String ) throws -> PostgresClient {
86
+ let host = Environment . get ( " DATABASE_HOST " ) !
87
+ let port = Environment . get ( " DATABASE_PORT " ) . flatMap ( Int . init) !
88
+ let username = Environment . get ( " DATABASE_USERNAME " ) !
89
+ let password = Environment . get ( " DATABASE_PASSWORD " ) !
58
90
59
- guard let tables = tableNamesCache. withLockedValue ( { $0 } ) else {
60
- struct Row : Decodable { var table_name : String }
61
- let tableNames = try await db. raw ( """
62
- SELECT table_name FROM
63
- information_schema.tables
64
- WHERE
65
- table_schema NOT IN ('pg_catalog', 'information_schema', 'public._fluent_migrations')
66
- AND table_schema NOT LIKE 'pg_toast%'
67
- AND table_name NOT LIKE '_fluent_%'
68
- """ )
69
- . all ( decoding: Row . self)
70
- . map ( \. table_name)
71
- tableNamesCache. withLockedValue { $0 = tableNames }
72
- try await _resetDb ( app)
73
- return
74
- }
91
+ let config = PostgresClient . Configuration ( host: host, port: port, username: username, password: password, database: databaseName, tls: . disable)
92
+ return . init( configuration: config)
93
+ }
94
+
95
+ func withDatabase( _ databaseName: String , _ query: @escaping ( PostgresClient ) async throws -> Void ) async throws {
96
+ let client = try connect ( to: databaseName)
97
+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
98
+ taskGroup. addTask {
99
+ await client. run ( )
100
+ }
101
+
102
+ try await query ( client)
75
103
76
- for table in tables {
77
- try await db. raw ( " TRUNCATE TABLE \( ident: table) CASCADE " ) . run ( )
104
+ taskGroup. cancelAll ( )
78
105
}
106
+ }
107
+
79
108
80
- try await RecentPackage . refresh ( on: app. db)
81
- try await RecentRelease . refresh ( on: app. db)
82
- try await Search . refresh ( on: app. db)
83
- try await Stats . refresh ( on: app. db)
84
- try await WeightedKeyword . refresh ( on: app. db)
109
+ private let tableNamesCache : NIOLockedValueBox < [ String ] ? > = . init( nil )
110
+ private let snapshotCreated = ActorIsolated ( false )
111
+
112
+ func _resetDb( ) async throws {
113
+ // FIXME: get this dynamically
114
+ let dbName = " spi_test "
115
+ let templateName = dbName + " _template "
116
+
117
+ try await snapshotCreated. withValue { snapshotCreated in
118
+ if snapshotCreated {
119
+ // delete db and re-create from snapshot
120
+ print ( " Deleting and re-creating from snapshot... " )
121
+ do {
122
+ try await withDatabase ( " postgres " ) { client in
123
+ try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( dbName) WITH (FORCE) " ) )
124
+ try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( dbName) TEMPLATE \( templateName) " ) )
125
+ }
126
+ } catch {
127
+ print ( String ( reflecting: error) )
128
+ throw error
129
+ }
130
+ print ( " Database reset. " )
131
+ } else {
132
+ // create snapshot
133
+ print ( " Creating snapshot... " )
134
+ do {
135
+ try await withDatabase ( " postgres " ) { client in
136
+ try await client. query ( PostgresQuery ( unsafeSQL: " DROP DATABASE IF EXISTS \( templateName) WITH (FORCE) " ) )
137
+ try await client. query ( PostgresQuery ( unsafeSQL: " CREATE DATABASE \( templateName) TEMPLATE \( dbName) " ) )
138
+ }
139
+ } catch {
140
+ print ( String ( reflecting: error) )
141
+ throw error
142
+ }
143
+ snapshotCreated = true
144
+ print ( " Snapshot created. " )
145
+ }
146
+ }
85
147
}
86
148
87
149
0 commit comments