Skip to content

Commit de87130

Browse files
committed
Add migrations tests
1 parent eace62d commit de87130

File tree

1 file changed

+145
-19
lines changed

1 file changed

+145
-19
lines changed

rust/impls/src/postgres_store.rs

Lines changed: 145 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,16 @@ const VALUE_COLUMN: &str = "value";
2828
const VERSION_COLUMN: &str = "version";
2929

3030
const DB_VERSION_COLUMN: &str = "db_version";
31+
#[cfg(test)]
32+
const MIGRATION_LOG_COLUMN: &str = "upgrade_from";
3133

3234
const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
3335
const INIT_DB_CMD: &str = "CREATE DATABASE";
3436
const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
3537
const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
3638
const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";
39+
#[cfg(test)]
40+
const GET_MIGRATION_LOG_STMT: &str = "SELECT upgrade_from FROM vss_db_upgrades;";
3741

3842
const MIGRATIONS: &[&str] = &[
3943
"CREATE TABLE vss_db_version (db_version INTEGER);",
@@ -102,6 +106,31 @@ async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Resu
102106
Ok(())
103107
}
104108

109+
#[cfg(test)]
110+
async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
111+
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
112+
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
113+
.await
114+
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
115+
// Connection must be driven on a separate task, and will resolve when the client is dropped
116+
tokio::spawn(async move {
117+
if let Err(e) = connection.await {
118+
eprintln!("Connection error: {}", e);
119+
}
120+
});
121+
122+
let drop_database_statement = format!("DROP DATABASE {};", db_name);
123+
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
124+
Error::new(
125+
ErrorKind::Other,
126+
format!("Failed to drop database {}: {}", db_name, e),
127+
)
128+
})?;
129+
assert_eq!(num_rows, 0);
130+
131+
Ok(())
132+
}
133+
105134
impl PostgresBackendImpl {
106135
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
107136
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
@@ -125,12 +154,13 @@ impl PostgresBackendImpl {
125154
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
126155
let postgres_backend = PostgresBackendImpl { pool };
127156

128-
postgres_backend.migrate_vss_database().await?;
157+
#[cfg(not(test))]
158+
postgres_backend.migrate_vss_database(MIGRATIONS).await?;
129159

130160
Ok(postgres_backend)
131161
}
132162

133-
async fn migrate_vss_database(&self) -> Result<(), Error> {
163+
async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
134164
let mut conn = self.pool.get().await.map_err(|e| {
135165
Error::new(
136166
ErrorKind::Other,
@@ -162,16 +192,16 @@ impl PostgresBackendImpl {
162192
.await
163193
.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;
164194

165-
if migration_start == MIGRATIONS.len() {
195+
if migration_start == migrations.len() {
166196
// No migrations needed, we are done
167-
return Ok(());
168-
} else if migration_start > MIGRATIONS.len() {
197+
return Ok((migration_start, migrations.len()));
198+
} else if migration_start > migrations.len() {
169199
panic!("We do not allow downgrades");
170200
}
171201

172-
println!("Applying migration(s) {} through {}", migration_start, MIGRATIONS.len() - 1);
202+
println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);
173203

174-
for (idx, &stmt) in (&MIGRATIONS[migration_start..]).iter().enumerate() {
204+
for (idx, &stmt) in (&migrations[migration_start..]).iter().enumerate() {
175205
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
176206
Error::new(
177207
ErrorKind::Other,
@@ -197,7 +227,7 @@ impl PostgresBackendImpl {
197227
assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");
198228

199229
let next_migration_start =
200-
i32::try_from(MIGRATIONS.len()).expect("Length is definitely smaller than i32::MAX");
230+
i32::try_from(migrations.len()).expect("Length is definitely smaller than i32::MAX");
201231
let num_rows =
202232
tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
203233
Error::new(
@@ -214,7 +244,21 @@ impl PostgresBackendImpl {
214244
Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
215245
})?;
216246

217-
Ok(())
247+
Ok((migration_start, migrations.len()))
248+
}
249+
250+
#[cfg(test)]
251+
async fn get_schema_version(&self) -> usize {
252+
let conn = self.pool.get().await.unwrap();
253+
let row = conn.query_one(GET_VERSION_STMT, &[]).await.unwrap();
254+
usize::try_from(row.get::<&str, i32>(DB_VERSION_COLUMN)).unwrap()
255+
}
256+
257+
#[cfg(test)]
258+
async fn get_upgrades_list(&self) -> Vec<usize> {
259+
let conn = self.pool.get().await.unwrap();
260+
let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap();
261+
rows.iter().map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()).collect()
218262
}
219263

220264
fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -568,23 +612,105 @@ mod tests {
568612
use crate::postgres_store::PostgresBackendImpl;
569613
use api::define_kv_store_tests;
570614
use tokio::sync::OnceCell;
615+
use super::{MIGRATIONS, drop_database};
616+
617+
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
618+
const MIGRATIONS_START: usize = 0;
619+
const MIGRATIONS_END: usize = MIGRATIONS.len();
571620

572621
static START: OnceCell<()> = OnceCell::const_new();
573622

574623
define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
624+
let db_name = "postgres_kv_store_tests";
575625
START
576626
.get_or_init(|| async {
577-
// Initialize the database once, and have other threads wait
578-
PostgresBackendImpl::new(
579-
"postgresql://postgres:postgres@localhost:5432",
580-
"postgres",
581-
)
582-
.await
583-
.unwrap();
627+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
628+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
629+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
630+
assert_eq!(start, MIGRATIONS_START);
631+
assert_eq!(end, MIGRATIONS_END);
584632
})
585633
.await;
586-
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432", "postgres")
587-
.await
588-
.unwrap()
634+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
635+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
636+
assert_eq!(start, MIGRATIONS_END);
637+
assert_eq!(end, MIGRATIONS_END);
638+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
639+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
640+
store
589641
});
642+
643+
#[tokio::test]
644+
#[should_panic(expected = "We do not allow downgrades")]
645+
async fn panic_on_downgrade() {
646+
let db_name = "panic_on_downgrade_test";
647+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
648+
{
649+
let mut migrations = MIGRATIONS.to_vec();
650+
migrations.push("SELECT 1 WHERE FALSE;");
651+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
652+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
653+
assert_eq!(start, MIGRATIONS_START);
654+
assert_eq!(end, MIGRATIONS_END + 1);
655+
};
656+
{
657+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
658+
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
659+
};
660+
}
661+
662+
#[tokio::test]
663+
async fn new_migrations_increments_upgrades() {
664+
let db_name = "new_migrations_increments_upgrades_test";
665+
let dummy_migration = "SELECT 1 WHERE FALSE;";
666+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
667+
{
668+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
669+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
670+
assert_eq!(start, MIGRATIONS_START);
671+
assert_eq!(end, MIGRATIONS_END);
672+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
673+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
674+
};
675+
{
676+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
677+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
678+
assert_eq!(start, MIGRATIONS_END);
679+
assert_eq!(end, MIGRATIONS_END);
680+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
681+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
682+
};
683+
684+
let mut migrations = MIGRATIONS.to_vec();
685+
migrations.push(dummy_migration);
686+
{
687+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
688+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
689+
assert_eq!(start, MIGRATIONS_END);
690+
assert_eq!(end, MIGRATIONS_END + 1);
691+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END]);
692+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 1);
693+
};
694+
695+
migrations.push(dummy_migration);
696+
migrations.push(dummy_migration);
697+
{
698+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
699+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
700+
assert_eq!(start, MIGRATIONS_END + 1);
701+
assert_eq!(end, MIGRATIONS_END + 3);
702+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
703+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3);
704+
};
705+
706+
{
707+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
708+
let list = store.get_upgrades_list().await;
709+
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
710+
let version = store.get_schema_version().await;
711+
assert_eq!(version, MIGRATIONS_END + 3);
712+
}
713+
714+
drop_database(POSTGRES_ENDPOINT, db_name).await.unwrap();
715+
}
590716
}

0 commit comments

Comments
 (0)