@@ -28,12 +28,16 @@ const VALUE_COLUMN: &str = "value";
2828const VERSION_COLUMN : & str = "version" ;
2929
3030const DB_VERSION_COLUMN : & str = "db_version" ;
31+ #[ cfg( test) ]
32+ const MIGRATION_LOG_COLUMN : & str = "upgrade_from" ;
3133
3234const CHECK_DB_STMT : & str = "SELECT 1 FROM pg_database WHERE datname = $1" ;
3335const INIT_DB_CMD : & str = "CREATE DATABASE" ;
3436const GET_VERSION_STMT : & str = "SELECT db_version FROM vss_db_version;" ;
3537const UPDATE_VERSION_STMT : & str = "UPDATE vss_db_version SET db_version=$1;" ;
3638const LOG_MIGRATION_STMT : & str = "INSERT INTO vss_db_upgrades VALUES($1);" ;
39+ #[ cfg( test) ]
40+ const GET_MIGRATION_LOG_STMT : & str = "SELECT upgrade_from FROM vss_db_upgrades;" ;
3741
3842const MIGRATIONS : & [ & str ] = & [
3943 "CREATE TABLE vss_db_version (db_version INTEGER);" ,
@@ -102,6 +106,31 @@ async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Resu
102106 Ok ( ( ) )
103107}
104108
109+ #[ cfg( test) ]
110+ async fn drop_database ( postgres_endpoint : & str , db_name : & str ) -> Result < ( ) , Error > {
111+ let postgres_dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
112+ let ( client, connection) = tokio_postgres:: connect ( & postgres_dsn, NoTls )
113+ . await
114+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
115+ // Connection must be driven on a separate task, and will resolve when the client is dropped
116+ tokio:: spawn ( async move {
117+ if let Err ( e) = connection. await {
118+ eprintln ! ( "Connection error: {}" , e) ;
119+ }
120+ } ) ;
121+
122+ let drop_database_statement = format ! ( "DROP DATABASE {};" , db_name) ;
123+ let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
124+ Error :: new (
125+ ErrorKind :: Other ,
126+ format ! ( "Failed to drop database {}: {}" , db_name, e) ,
127+ )
128+ } ) ?;
129+ assert_eq ! ( num_rows, 0 ) ;
130+
131+ Ok ( ( ) )
132+ }
133+
105134impl PostgresBackendImpl {
106135 /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
107136 pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
@@ -125,12 +154,13 @@ impl PostgresBackendImpl {
125154 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
126155 let postgres_backend = PostgresBackendImpl { pool } ;
127156
128- postgres_backend. migrate_vss_database ( ) . await ?;
157+ #[ cfg( not( test) ) ]
158+ postgres_backend. migrate_vss_database ( MIGRATIONS ) . await ?;
129159
130160 Ok ( postgres_backend)
131161 }
132162
133- async fn migrate_vss_database ( & self ) -> Result < ( ) , Error > {
163+ async fn migrate_vss_database ( & self , migrations : & [ & str ] ) -> Result < ( usize , usize ) , Error > {
134164 let mut conn = self . pool . get ( ) . await . map_err ( |e| {
135165 Error :: new (
136166 ErrorKind :: Other ,
@@ -162,16 +192,16 @@ impl PostgresBackendImpl {
162192 . await
163193 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Transaction start error: {}" , e) ) ) ?;
164194
165- if migration_start == MIGRATIONS . len ( ) {
195+ if migration_start == migrations . len ( ) {
166196 // No migrations needed, we are done
167- return Ok ( ( ) ) ;
168- } else if migration_start > MIGRATIONS . len ( ) {
197+ return Ok ( ( migration_start , migrations . len ( ) ) ) ;
198+ } else if migration_start > migrations . len ( ) {
169199 panic ! ( "We do not allow downgrades" ) ;
170200 }
171201
172- println ! ( "Applying migration(s) {} through {}" , migration_start, MIGRATIONS . len( ) - 1 ) ;
202+ println ! ( "Applying migration(s) {} through {}" , migration_start, migrations . len( ) - 1 ) ;
173203
174- for ( idx, & stmt) in ( & MIGRATIONS [ migration_start..] ) . iter ( ) . enumerate ( ) {
204+ for ( idx, & stmt) in ( & migrations [ migration_start..] ) . iter ( ) . enumerate ( ) {
175205 let _num_rows = tx. execute ( stmt, & [ ] ) . await . map_err ( |e| {
176206 Error :: new (
177207 ErrorKind :: Other ,
@@ -197,7 +227,7 @@ impl PostgresBackendImpl {
197227 assert_eq ! ( num_rows, 1 , "LOG_MIGRATION_STMT should only add one row at a time" ) ;
198228
199229 let next_migration_start =
200- i32:: try_from ( MIGRATIONS . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
230+ i32:: try_from ( migrations . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
201231 let num_rows =
202232 tx. execute ( UPDATE_VERSION_STMT , & [ & next_migration_start] ) . await . map_err ( |e| {
203233 Error :: new (
@@ -214,7 +244,21 @@ impl PostgresBackendImpl {
214244 Error :: new ( ErrorKind :: Other , format ! ( "Transaction commit error: {}" , e) )
215245 } ) ?;
216246
217- Ok ( ( ) )
247+ Ok ( ( migration_start, migrations. len ( ) ) )
248+ }
249+
250+ #[ cfg( test) ]
251+ async fn get_schema_version ( & self ) -> usize {
252+ let conn = self . pool . get ( ) . await . unwrap ( ) ;
253+ let row = conn. query_one ( GET_VERSION_STMT , & [ ] ) . await . unwrap ( ) ;
254+ usize:: try_from ( row. get :: < & str , i32 > ( DB_VERSION_COLUMN ) ) . unwrap ( )
255+ }
256+
257+ #[ cfg( test) ]
258+ async fn get_upgrades_list ( & self ) -> Vec < usize > {
259+ let conn = self . pool . get ( ) . await . unwrap ( ) ;
260+ let rows = conn. query ( GET_MIGRATION_LOG_STMT , & [ ] ) . await . unwrap ( ) ;
261+ rows. iter ( ) . map ( |row| usize:: try_from ( row. get :: < & str , i32 > ( MIGRATION_LOG_COLUMN ) ) . unwrap ( ) ) . collect ( )
218262 }
219263
220264 fn build_vss_record ( & self , user_token : String , store_id : String , kv : KeyValue ) -> VssDbRecord {
@@ -568,23 +612,105 @@ mod tests {
568612 use crate :: postgres_store:: PostgresBackendImpl ;
569613 use api:: define_kv_store_tests;
570614 use tokio:: sync:: OnceCell ;
615+ use super :: { MIGRATIONS , drop_database} ;
616+
617+ const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
618+ const MIGRATIONS_START : usize = 0 ;
619+ const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
571620
572621 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
573622
574623 define_kv_store_tests ! ( PostgresKvStoreTest , PostgresBackendImpl , {
624+ let db_name = "postgres_kv_store_tests" ;
575625 START
576626 . get_or_init( || async {
577- // Initialize the database once, and have other threads wait
578- PostgresBackendImpl :: new(
579- "postgresql://postgres:postgres@localhost:5432" ,
580- "postgres" ,
581- )
582- . await
583- . unwrap( ) ;
627+ let _ = drop_database( POSTGRES_ENDPOINT , db_name) . await ;
628+ let store = PostgresBackendImpl :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
629+ let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
630+ assert_eq!( start, MIGRATIONS_START ) ;
631+ assert_eq!( end, MIGRATIONS_END ) ;
584632 } )
585633 . await ;
586- PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432" , "postgres" )
587- . await
588- . unwrap( )
634+ let store = PostgresBackendImpl :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
635+ let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
636+ assert_eq!( start, MIGRATIONS_END ) ;
637+ assert_eq!( end, MIGRATIONS_END ) ;
638+ assert_eq!( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
639+ assert_eq!( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
640+ store
589641 } ) ;
642+
643+ #[ tokio:: test]
644+ #[ should_panic( expected = "We do not allow downgrades" ) ]
645+ async fn panic_on_downgrade ( ) {
646+ let db_name = "panic_on_downgrade_test" ;
647+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name) . await ;
648+ {
649+ let mut migrations = MIGRATIONS . to_vec ( ) ;
650+ migrations. push ( "SELECT 1 WHERE FALSE;" ) ;
651+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
652+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
653+ assert_eq ! ( start, MIGRATIONS_START ) ;
654+ assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
655+ } ;
656+ {
657+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
658+ let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
659+ } ;
660+ }
661+
662+ #[ tokio:: test]
663+ async fn new_migrations_increments_upgrades ( ) {
664+ let db_name = "new_migrations_increments_upgrades_test" ;
665+ let dummy_migration = "SELECT 1 WHERE FALSE;" ;
666+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name) . await ;
667+ {
668+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
669+ let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
670+ assert_eq ! ( start, MIGRATIONS_START ) ;
671+ assert_eq ! ( end, MIGRATIONS_END ) ;
672+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
673+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
674+ } ;
675+ {
676+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
677+ let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
678+ assert_eq ! ( start, MIGRATIONS_END ) ;
679+ assert_eq ! ( end, MIGRATIONS_END ) ;
680+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
681+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
682+ } ;
683+
684+ let mut migrations = MIGRATIONS . to_vec ( ) ;
685+ migrations. push ( dummy_migration) ;
686+ {
687+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
688+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
689+ assert_eq ! ( start, MIGRATIONS_END ) ;
690+ assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
691+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START , MIGRATIONS_END ] ) ;
692+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END + 1 ) ;
693+ } ;
694+
695+ migrations. push ( dummy_migration) ;
696+ migrations. push ( dummy_migration) ;
697+ {
698+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
699+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
700+ assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
701+ assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
702+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
703+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END + 3 ) ;
704+ } ;
705+
706+ {
707+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
708+ let list = store. get_upgrades_list ( ) . await ;
709+ assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
710+ let version = store. get_schema_version ( ) . await ;
711+ assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
712+ }
713+
714+ drop_database ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
715+ }
590716}
0 commit comments