@@ -28,12 +28,18 @@ const VALUE_COLUMN: &str = "value";
2828const VERSION_COLUMN : & str = "version" ;
2929
3030const DB_VERSION_COLUMN : & str = "db_version" ;
31+ #[ cfg( test) ]
32+ const MIGRATION_LOG_COLUMN : & str = "upgrade_from" ;
3133
3234const CHECK_DB_STMT : & str = "SELECT 1 FROM pg_database WHERE datname = $1" ;
3335const INIT_DB_CMD : & str = "CREATE DATABASE" ;
36+ #[ cfg( test) ]
37+ const DROP_DB_CMD : & str = "DROP DATABASE" ;
3438const GET_VERSION_STMT : & str = "SELECT db_version FROM vss_db_version;" ;
3539const UPDATE_VERSION_STMT : & str = "UPDATE vss_db_version SET db_version=$1;" ;
3640const LOG_MIGRATION_STMT : & str = "INSERT INTO vss_db_upgrades VALUES($1);" ;
41+ #[ cfg( test) ]
42+ const GET_MIGRATION_LOG_STMT : & str = "SELECT upgrade_from FROM vss_db_upgrades;" ;
3743
3844const MIGRATIONS : & [ & str ] = & [
3945 "CREATE TABLE vss_db_version (db_version INTEGER);" ,
@@ -52,6 +58,8 @@ const MIGRATIONS: &[&str] = &[
5258 PRIMARY KEY (user_token, store_id, key)
5359 );" ,
5460] ;
61+ #[ cfg( test) ]
62+ const DUMMY_MIGRATION : & str = "SELECT 1 WHERE FALSE;" ;
5563
5664/// The maximum number of key versions that can be returned in a single page.
5765///
@@ -102,6 +110,31 @@ async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Resu
102110 Ok ( ( ) )
103111}
104112
113+ #[ cfg( test) ]
114+ async fn drop_database ( postgres_endpoint : & str , db_name : & str ) -> Result < ( ) , Error > {
115+ let postgres_dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
116+ let ( client, connection) = tokio_postgres:: connect ( & postgres_dsn, NoTls )
117+ . await
118+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
119+ // Connection must be driven on a separate task, and will resolve when the client is dropped
120+ tokio:: spawn ( async move {
121+ if let Err ( e) = connection. await {
122+ eprintln ! ( "Connection error: {}" , e) ;
123+ }
124+ } ) ;
125+
126+ let drop_database_statement = format ! ( "{} {};" , DROP_DB_CMD , db_name) ;
127+ let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
128+ Error :: new (
129+ ErrorKind :: Other ,
130+ format ! ( "Failed to drop database {}: {}" , db_name, e) ,
131+ )
132+ } ) ?;
133+ assert_eq ! ( num_rows, 0 ) ;
134+
135+ Ok ( ( ) )
136+ }
137+
105138impl PostgresBackendImpl {
106139 /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
107140 pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
@@ -125,12 +158,13 @@ impl PostgresBackendImpl {
125158 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
126159 let postgres_backend = PostgresBackendImpl { pool } ;
127160
128- postgres_backend. migrate_vss_database ( ) . await ?;
161+ #[ cfg( not( test) ) ]
162+ postgres_backend. migrate_vss_database ( MIGRATIONS ) . await ?;
129163
130164 Ok ( postgres_backend)
131165 }
132166
133- async fn migrate_vss_database ( & self ) -> Result < ( ) , Error > {
167+ async fn migrate_vss_database ( & self , migrations : & [ & str ] ) -> Result < ( usize , usize ) , Error > {
134168 let mut conn = self . pool . get ( ) . await . map_err ( |e| {
135169 Error :: new (
136170 ErrorKind :: Other ,
@@ -162,16 +196,16 @@ impl PostgresBackendImpl {
162196 . await
163197 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Transaction start error: {}" , e) ) ) ?;
164198
165- if migration_start == MIGRATIONS . len ( ) {
199+ if migration_start == migrations . len ( ) {
166200 // No migrations needed, we are done
167- return Ok ( ( ) ) ;
168- } else if migration_start > MIGRATIONS . len ( ) {
201+ return Ok ( ( migration_start , migrations . len ( ) ) ) ;
202+ } else if migration_start > migrations . len ( ) {
169203 panic ! ( "We do not allow downgrades" ) ;
170204 }
171205
172- println ! ( "Applying migration(s) {} through {}" , migration_start, MIGRATIONS . len( ) - 1 ) ;
206+ println ! ( "Applying migration(s) {} through {}" , migration_start, migrations . len( ) - 1 ) ;
173207
174- for ( idx, & stmt) in ( & MIGRATIONS [ migration_start..] ) . iter ( ) . enumerate ( ) {
208+ for ( idx, & stmt) in ( & migrations [ migration_start..] ) . iter ( ) . enumerate ( ) {
175209 let _num_rows = tx. execute ( stmt, & [ ] ) . await . map_err ( |e| {
176210 Error :: new (
177211 ErrorKind :: Other ,
@@ -197,7 +231,7 @@ impl PostgresBackendImpl {
197231 assert_eq ! ( num_rows, 1 , "LOG_MIGRATION_STMT should only add one row at a time" ) ;
198232
199233 let next_migration_start =
200- i32:: try_from ( MIGRATIONS . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
234+ i32:: try_from ( migrations . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
201235 let num_rows =
202236 tx. execute ( UPDATE_VERSION_STMT , & [ & next_migration_start] ) . await . map_err ( |e| {
203237 Error :: new (
@@ -214,7 +248,21 @@ impl PostgresBackendImpl {
214248 Error :: new ( ErrorKind :: Other , format ! ( "Transaction commit error: {}" , e) )
215249 } ) ?;
216250
217- Ok ( ( ) )
251+ Ok ( ( migration_start, migrations. len ( ) ) )
252+ }
253+
254+ #[ cfg( test) ]
255+ async fn get_schema_version ( & self ) -> usize {
256+ let conn = self . pool . get ( ) . await . unwrap ( ) ;
257+ let row = conn. query_one ( GET_VERSION_STMT , & [ ] ) . await . unwrap ( ) ;
258+ usize:: try_from ( row. get :: < & str , i32 > ( DB_VERSION_COLUMN ) ) . unwrap ( )
259+ }
260+
261+ #[ cfg( test) ]
262+ async fn get_upgrades_list ( & self ) -> Vec < usize > {
263+ let conn = self . pool . get ( ) . await . unwrap ( ) ;
264+ let rows = conn. query ( GET_MIGRATION_LOG_STMT , & [ ] ) . await . unwrap ( ) ;
265+ rows. iter ( ) . map ( |row| usize:: try_from ( row. get :: < & str , i32 > ( MIGRATION_LOG_COLUMN ) ) . unwrap ( ) ) . collect ( )
218266 }
219267
220268 fn build_vss_record ( & self , user_token : String , store_id : String , kv : KeyValue ) -> VssDbRecord {
@@ -568,23 +616,104 @@ mod tests {
568616 use crate :: postgres_store:: PostgresBackendImpl ;
569617 use api:: define_kv_store_tests;
570618 use tokio:: sync:: OnceCell ;
619+ use super :: { MIGRATIONS , DUMMY_MIGRATION , drop_database} ;
620+
621+ const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
622+ const MIGRATIONS_START : usize = 0 ;
623+ const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
571624
572625 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
573626
574627 define_kv_store_tests ! ( PostgresKvStoreTest , PostgresBackendImpl , {
628+ let db_name = "postgres_kv_store_tests" ;
575629 START
576630 . get_or_init( || async {
577- // Initialize the database once, and have other threads wait
578- PostgresBackendImpl :: new(
579- "postgresql://postgres:postgres@localhost:5432" ,
580- "postgres" ,
581- )
582- . await
583- . unwrap( ) ;
631+ let _ = drop_database( POSTGRES_ENDPOINT , db_name) . await ;
632+ let store = PostgresBackendImpl :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
633+ let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
634+ assert_eq!( start, MIGRATIONS_START ) ;
635+ assert_eq!( end, MIGRATIONS_END ) ;
584636 } )
585637 . await ;
586- PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432" , "postgres" )
587- . await
588- . unwrap( )
638+ let store = PostgresBackendImpl :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
639+ let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
640+ assert_eq!( start, MIGRATIONS_END ) ;
641+ assert_eq!( end, MIGRATIONS_END ) ;
642+ assert_eq!( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
643+ assert_eq!( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
644+ store
589645 } ) ;
646+
647+ #[ tokio:: test]
648+ #[ should_panic( expected = "We do not allow downgrades" ) ]
649+ async fn panic_on_downgrade ( ) {
650+ let db_name = "panic_on_downgrade_test" ;
651+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name) . await ;
652+ {
653+ let mut migrations = MIGRATIONS . to_vec ( ) ;
654+ migrations. push ( DUMMY_MIGRATION ) ;
655+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
656+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
657+ assert_eq ! ( start, MIGRATIONS_START ) ;
658+ assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
659+ } ;
660+ {
661+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
662+ let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
663+ } ;
664+ }
665+
666+ #[ tokio:: test]
667+ async fn new_migrations_increments_upgrades ( ) {
668+ let db_name = "new_migrations_increments_upgrades_test" ;
669+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name) . await ;
670+ {
671+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
672+ let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
673+ assert_eq ! ( start, MIGRATIONS_START ) ;
674+ assert_eq ! ( end, MIGRATIONS_END ) ;
675+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
676+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
677+ } ;
678+ {
679+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
680+ let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
681+ assert_eq ! ( start, MIGRATIONS_END ) ;
682+ assert_eq ! ( end, MIGRATIONS_END ) ;
683+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
684+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
685+ } ;
686+
687+ let mut migrations = MIGRATIONS . to_vec ( ) ;
688+ migrations. push ( DUMMY_MIGRATION ) ;
689+ {
690+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
691+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
692+ assert_eq ! ( start, MIGRATIONS_END ) ;
693+ assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
694+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START , MIGRATIONS_END ] ) ;
695+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END + 1 ) ;
696+ } ;
697+
698+ migrations. push ( DUMMY_MIGRATION ) ;
699+ migrations. push ( DUMMY_MIGRATION ) ;
700+ {
701+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
702+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
703+ assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
704+ assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
705+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
706+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END + 3 ) ;
707+ } ;
708+
709+ {
710+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
711+ let list = store. get_upgrades_list ( ) . await ;
712+ assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
713+ let version = store. get_schema_version ( ) . await ;
714+ assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
715+ }
716+
717+ drop_database ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
718+ }
590719}
0 commit comments